Pre-trained NFNets with 99% of the accuracy of the official paper

Overview

NFNet PyTorch Implementation

This repo contains pretrained NFNet models F0-F6 with high ImageNet accuracy from the paper High-Performance Large-Scale Image Recognition Without Normalization (Brock et al., 2021). According to the paper, the smaller models match the accuracy of EfficientNet-B7 while training up to 8.7 times faster, and the largest models set a new state-of-the-art top-1 accuracy of 86.5% on ImageNet.

NFNet                                     F0     F1     F2     F3     F4     F5     F6+SAM
Top-1 accuracy, Brock et al. (%)          83.6   84.7   85.1   85.7   85.9   86.0   86.5
Top-1 accuracy, this implementation (%)   82.82  84.63  84.90  85.46  85.66  85.62  TBD
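As a quick check of the "99%" in the title, the ratio of this implementation's top-1 accuracy to the paper's can be computed per variant. The values below are copied from the accuracy table; F6+SAM is left out because its result is still TBD.

```python
# Top-1 accuracies (%) from the table: (Brock et al., this implementation)
ACCURACIES = {
    "F0": (83.6, 82.82),
    "F1": (84.7, 84.63),
    "F2": (85.1, 84.90),
    "F3": (85.7, 85.46),
    "F4": (85.9, 85.66),
    "F5": (86.0, 85.62),
}


def relative_accuracy(paper: float, ours: float) -> float:
    """Fraction of the paper's top-1 accuracy reached by this implementation."""
    return ours / paper


for variant, (paper, ours) in ACCURACIES.items():
    print(f"{variant}: {relative_accuracy(paper, ours):.2%}")
```

Every listed variant reaches more than 99% of the paper's accuracy; F0 has the lowest ratio (about 99.07%, a gap of 0.78 points).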

All credit goes to the authors of the original paper. This repo is heavily inspired by their JAX implementation in the official repository; please visit that repository for citation information.

Get started

git clone https://github.com/benjs/nfnets_pytorch.git
cd nfnets_pytorch
pip3 install -r requirements.txt

Download the pretrained weights (e.g. F0_haiku.npz) from the official repository and place them in the pretrained/ folder.

from pretrained import pretrained_nfnet
model_F0 = pretrained_nfnet('pretrained/F0_haiku.npz')
model_F1 = pretrained_nfnet('pretrained/F1_haiku.npz')
# ...

The model variant is automatically derived from the parameter count in the pretrained weights file.
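The README does not spell out how the parameter count is mapped to a variant. One plausible approach, sketched below as an assumption rather than the repo's actual code, is to add up the sizes of all arrays in the weights file and pick the variant whose expected count is closest. The helper names `count_parameters` and `closest_variant` are hypothetical, and the reference counts must come from the model definitions; none are hard-coded here.

```python
import math
from typing import Mapping, Sequence


def count_parameters(shapes: Mapping[str, Sequence[int]]) -> int:
    """Total number of scalars across all weight arrays (name -> shape)."""
    # math.prod(()) == 1, so scalar entries count as one parameter.
    return sum(math.prod(shape) for shape in shapes.values())


def closest_variant(num_params: int, reference_counts: Mapping[str, int]) -> str:
    """Return the variant whose reference parameter count is nearest to num_params."""
    return min(reference_counts, key=lambda v: abs(reference_counts[v] - num_params))
```

With NumPy, the shapes could be collected via `data = np.load('pretrained/F0_haiku.npz')` followed by `{name: data[name].shape for name in data.files}`. Matching on the nearest count rather than an exact count tolerates small differences, such as a few extra scalar entries in the file.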

Validate yourself

python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset path/to/imagenet/valset/

You can download the ImageNet validation set from the ILSVRC2012 challenge site after requesting access, for instance with your .edu email address.
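The accuracies in the table are top-1 accuracies: the fraction of validation images whose highest-scoring class equals the ground-truth label. A minimal PyTorch sketch of that metric follows; `top1_accuracy` is a hypothetical helper for illustration, not the code in eval.py, and it assumes the model returns one row of class scores per image.

```python
from typing import Iterable, Tuple

import torch
from torch import nn


@torch.no_grad()
def top1_accuracy(
    model: nn.Module,
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    device: str = "cpu",
) -> float:
    """Fraction of samples whose arg-max class score matches the label."""
    model.eval()
    correct, total = 0, 0
    for images, labels in batches:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)  # index of the highest score per row
        correct += (predictions == labels).sum().item()
        total += labels.numel()
    return correct / total
```

A torch.utils.data.DataLoader over the ImageNet validation set can be passed directly as `batches`, since it yields (images, labels) pairs.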

Scaled weight standardization convolutions in your own model

Simply replace all your nn.Conv2d layers with WSConv2D and all your nn.ReLU activations with VPReLU or VPGELU (variance-preserving ReLU/GELU).

import torch.nn as nn
from model import WSConv2D, VPReLU, VPGELU

# Replace nn.Conv2d with WSConv2D and nn.ReLU with VPReLU (or VPGELU)
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.activation = VPReLU(inplace=True)  # or VPGELU
        # further arguments omitted
        self.conv0 = WSConv2D(in_channels=128, out_channels=256, kernel_size=1)
        # ...

    def forward(self, x):
        out = self.activation(self.conv0(x))
        # ...
        return out
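For intuition about what these layers compute, here is a minimal sketch of both ideas as described in the paper. Scaled weight standardization normalizes each output channel's weights to zero mean and variance 1/fan_in before applying a learnable per-channel gain, and a variance-preserving activation multiplies ReLU by a constant gamma so that its output has unit variance for inputs drawn from N(0, 1). The class names `ScaledWSConv2d` and `VarianceScaledReLU` are hypothetical and are not the repo's WSConv2D/VPReLU; the eps value of 1e-4 is an assumption. Gamma is derived analytically here from Var[relu(x)] = 1/2 - 1/(2π); implementations may instead use a numerically estimated constant that differs slightly.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class ScaledWSConv2d(nn.Conv2d):
    """Conv2d using W_hat = gain * (W - mean) / sqrt(max(var * fan_in, eps)) per output channel."""

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        # Learnable gain per output channel, initialised to 1 (assumption).
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps

    def standardized_weight(self) -> torch.Tensor:
        fan_in = self.weight[0].numel()  # in_channels / groups * kernel_h * kernel_w
        mean = self.weight.mean(dim=(1, 2, 3), keepdim=True)
        var = (self.weight - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True)
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        return (self.weight - mean) * scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the default zero padding mode is handled in this sketch.
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


# For x ~ N(0, 1): E[relu(x)^2] = 1/2 and E[relu(x)] = 1/sqrt(2*pi),
# so Var[relu(x)] = 1/2 - 1/(2*pi) and gamma = 1 / sqrt(Var[relu(x)]) ~ 1.7129.
RELU_GAMMA = 1.0 / math.sqrt(0.5 - 0.5 / math.pi)


class VarianceScaledReLU(nn.Module):
    """ReLU rescaled so that unit-Gaussian inputs give unit-variance outputs."""

    def __init__(self, gamma: float = RELU_GAMMA):
        super().__init__()
        self.gamma = gamma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x) * self.gamma
```

Both sketches assume PyTorch's weight layout (out_channels, in_channels / groups, kernel_h, kernel_w), which is why the standardization reduces over every dimension except the first.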

SGD with adaptive gradient clipping in your own model

Simply replace your SGD optimizer with SGD_AGC.

from optim import SGD_AGC

optimizer = SGD_AGC(
    named_params=model.named_parameters(),  # Pass named parameters
    lr=1e-3,
    momentum=0.9,
    clipping=0.1,  # New clipping parameter
    weight_decay=2e-5,
    nesterov=True)
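The paper's unit-wise clipping rule rescales the gradient G_i of each unit (a row of a linear layer or an output channel of a convolution) whenever ‖G_i‖ / max(‖W_i‖, ε) exceeds the clipping threshold λ, setting G_i ← λ · max(‖W_i‖, ε) / ‖G_i‖ · G_i. Below is a minimal PyTorch sketch of that rule, assuming ε = 1e-3 and applying it directly to `.grad` tensors; `unitwise_norm` and `adaptive_clip_grad_` are hypothetical helpers, not the internals of SGD_AGC.

```python
import torch


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit; a single norm for scalars and vectors (e.g. biases, gains)."""
    if x.ndim <= 1:
        return torch.linalg.vector_norm(x)
    # PyTorch stores output units along dim 0, so reduce over all remaining dims.
    return torch.linalg.vector_norm(x, dim=tuple(range(1, x.ndim)), keepdim=True)


@torch.no_grad()
def adaptive_clip_grad_(parameters, clipping: float, eps: float = 1e-3) -> None:
    """Rescale each unit's gradient in place so that ||G_i|| <= clipping * max(||W_i||, eps)."""
    for p in parameters:
        if p.grad is None:
            continue
        max_norm = unitwise_norm(p).clamp_min(eps) * clipping
        grad_norm = unitwise_norm(p.grad)
        # Guard against division by zero for all-zero gradients.
        clipped = p.grad * (max_norm / grad_norm.clamp_min(1e-6))
        p.grad.copy_(torch.where(grad_norm > max_norm, clipped, p.grad))
```

With a plain torch.optim.SGD, such a helper would be called between loss.backward() and optimizer.step(), passing every parameter except those of the final linear layer.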

It is important to exclude certain parameters from clipping and weight decay. The authors recommend excluding the final fully connected (linear) layer from clipping and the bias/gain parameters from weight decay:

import re

for group in optimizer.param_groups:
    name = group['name']

    # Exclude biases and gains from weight decay
    if re.search(r'stem.*(bias|gain)|conv.*(bias|gain)|skip_gain', name):
        group['weight_decay'] = 0

    # Exclude the final linear layer from clipping
    if name.startswith('linear'):
        group['clipping'] = None
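The exclusion rules for weight decay and clipping can also be checked in isolation by running them over parameter names. The sketch below wraps the same regular expression and prefix test in a helper; `exclusions` is a hypothetical name, and the parameter names are made up for illustration, since the real ones depend on the model definition.

```python
import re

# Biases/gains of stem and conv layers, plus skip gains, skip weight decay.
NO_WEIGHT_DECAY = re.compile(r'stem.*(bias|gain)|conv.*(bias|gain)|skip_gain')


def exclusions(name: str) -> dict:
    """Report which optimizer settings a parameter named `name` is excluded from."""
    return {
        'weight_decay': NO_WEIGHT_DECAY.search(name) is not None,
        'clipping': name.startswith('linear'),
    }


# Hypothetical parameter names, for illustration only.
for name in ['stem.conv0.bias', 'body.0.conv1.gain', 'body.0.skip_gain',
             'body.0.conv1.weight', 'linear.weight']:
    print(name, exclusions(name))
```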

Train your own NFNet

Adjust your desired parameters in default_config.yaml and start training.

python3 train.py --dataset /path/to/imagenet/

Some parts are still missing for complete training from scratch:

  • Multi-GPU training
  • Data augmentations
  • FP16 activations and gradients

Contribute

The implementation is still at an early stage in terms of usability and testing. If you have an idea for improving this repo, open an issue, start a discussion, or submit a pull request.

Development status

  • Pre-trained NFNet Models
    • F0-F5
    • F6+SAM
    • Scaled weight standardization
    • Squeeze and excite
    • Stochastic depth
    • FP16 activations
  • SGD with unit adaptive gradient clipping (SGD-AGC)
    • Exclude certain layers from weight-decay, clipping
    • FP16 gradients
  • PyPI package
  • PyTorch hub submission
  • Label smoothing loss from Szegedy et al.
  • Training on ImageNet
  • Pre-trained weights
  • Tensorboard support
  • General usability improvements
  • Multi-GPU support
  • Data augmentation
  • Signal propagation plots (from first paper)
Releases(v0.0.1)
Owner
Benjamin Schmidt
Engineering Student