Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition

Zen-NAS is a lightning-fast, training-free Neural Architecture Search (NAS) algorithm that automatically designs deep neural networks with high prediction accuracy and high inference speed on GPUs and mobile devices.

This repository contains pre-trained models, a mini framework for zero-shot NAS search, and scripts to reproduce our results. You can also customize your own search space and develop a new zero-shot NAS proxy using our pipeline. Contributions are welcome.

The arXiv version of our paper is available online. It will appear in ICCV 2021; see "How to Cite This Work" for the BibTeX entry.

How Fast

Searching for 12 hours on a single GPU, Zen-NAS designs networks with ImageNet top-1 accuracy of about 83.6%, comparable to EfficientNet-B5, while running 4.9x faster on NVIDIA V100, 10x faster on NVIDIA T4 and 1.6x faster on Google Pixel2.

[Figure: Inference Speed]

Comparison with Other Zero-Shot NAS Proxies on CIFAR-10/100

We use the ResNet-like search space and search for models within a parameter budget of 1M. All models are searched with the same evolutionary strategy and trained on CIFAR-10/100 for 1440 epochs with AutoAugment, cosine learning-rate decay and weight decay 5e-4. The top-1 accuracies are reported in the following table:

Proxy       CIFAR-10   CIFAR-100
Zen-NAS     96.2%      80.1%
FLOPs       93.1%      64.7%
grad-norm   92.8%      65.4%
synflow     95.1%      75.9%
TE-NAS      96.1%      77.2%
NASWOT      96.0%      77.5%
Random      93.5%      71.1%

Please check our paper for more details.
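The 1M parameter budget in this comparison is a hard limit on each candidate's parameter count. As a rough illustration of how such counts (and the FLOPs columns in the tables below) can be computed, this sketch tallies parameters and multiply-accumulate operations for a plain chain of Conv-BN layers. It assumes bias-free, ungrouped convolutions with padding k // 2 and counts one multiply-accumulate as one FLOP; the layer tuple format and helper names are hypothetical, and analyze_model.py may use different conventions.

```python
from typing import List, Tuple

# Hypothetical layer spec: (in_channels, out_channels, kernel_size, stride)
ConvSpec = Tuple[int, int, int, int]


def conv_bn_cost(spec: ConvSpec, height: int, width: int) -> Tuple[int, int, int, int]:
    """Return (params, MACs, out_height, out_width) for Conv2d(bias=False) + BatchNorm2d."""
    c_in, c_out, k, stride = spec
    pad = k // 2
    out_h = (height + 2 * pad - k) // stride + 1
    out_w = (width + 2 * pad - k) // stride + 1
    conv_params = c_in * c_out * k * k   # weights only: no bias, groups=1
    bn_params = 2 * c_out                # BatchNorm affine weight and bias
    macs = conv_params * out_h * out_w   # one multiply-accumulate per weight per output pixel
    return conv_params + bn_params, macs, out_h, out_w


def network_cost(layers: List[ConvSpec], resolution: int) -> Tuple[int, int]:
    """Sum params and MACs over a plain chain of Conv-BN layers."""
    total_params, total_macs = 0, 0
    h = w = resolution
    for spec in layers:
        params, macs, h, w = conv_bn_cost(spec, h, w)
        total_params += params
        total_macs += macs
    return total_params, total_macs


if __name__ == "__main__":
    # A toy CIFAR-sized (32x32) network, not a searched ZenNet.
    toy = [(3, 16, 3, 1), (16, 32, 3, 2), (32, 64, 3, 2)]
    params, macs = network_cost(toy, resolution=32)
    print(f"params={params}, MACs={macs}, under 1M budget: {params < 1_000_000}")
```

Residual branches, the classifier head and depthwise convolutions would each need their own terms; the sketch only shows the basic arithmetic.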

Pre-trained Models

We provide pre-trained models on ImageNet and CIFAR-10/CIFAR-100.

ImageNet Models

Model                                  Resolution  # Params  FLOPs  Top-1 Acc  V100 (ms)  T4 (ms)  Pixel2 (ms)
zennet_imagenet1k_flops400M_SE_res224  224         5.7M      410M   78.0%      0.25       0.39     87.9
zennet_imagenet1k_flops600M_SE_res224  224         7.1M      611M   79.1%      0.36       0.52     128.6
zennet_imagenet1k_flops900M_SE_res224  224         19.4M     934M   80.8%      0.55       0.55     215.7
zennet_imagenet1k_latency01ms_res224   224         30.1M     1.7B   77.8%      0.1        0.08     181.7
zennet_imagenet1k_latency02ms_res224   224         49.7M     3.4B   80.8%      0.2        0.15     357.4
zennet_imagenet1k_latency03ms_res224   224         85.4M     4.8B   81.5%      0.3        0.20     517.0
zennet_imagenet1k_latency05ms_res224   224         118M      8.3B   82.7%      0.5        0.30     798.7
zennet_imagenet1k_latency08ms_res224   224         183M      13.9B  83.0%      0.8        0.57     1365
zennet_imagenet1k_latency12ms_res224   224         180M      22.0B  83.6%      1.2        0.85     2051
EfficientNet-B3                        300         12.0M     1.8B   81.1%      1.12       1.86     569.3
EfficientNet-B5                        456         30.0M     9.9B   83.3%      4.5        7.0      2580
EfficientNet-B6                        528         43M       19.0B  84.0%      7.64       12.3     4288
  • 'V100' is the inference latency on NVIDIA V100 in milliseconds, benchmarked at batch size 64, float16.
  • 'T4' is the inference latency on NVIDIA T4 in milliseconds, benchmarked at batch size 64, TensorRT INT8.
  • 'Pixel2' is the inference latency on Google Pixel2 in milliseconds, benchmarked on a single image.
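To pick an --arch value for a given deployment target, the latency columns above can be turned into a simple lookup. The sketch below copies the ZenNet rows of the table and returns the most accurate model whose benchmarked latency fits a budget on the chosen device. The function name and data layout are our own, and the numbers only hold under the benchmark settings listed above (batch size, precision, TensorRT INT8 on T4).

```python
from typing import Optional

# (model name, top-1 %, V100 ms, T4 ms, Pixel2 ms), copied from the ImageNet table above.
ZENNET_IMAGENET = [
    ("zennet_imagenet1k_flops400M_SE_res224", 78.0, 0.25, 0.39, 87.9),
    ("zennet_imagenet1k_flops600M_SE_res224", 79.1, 0.36, 0.52, 128.6),
    ("zennet_imagenet1k_flops900M_SE_res224", 80.8, 0.55, 0.55, 215.7),
    ("zennet_imagenet1k_latency01ms_res224", 77.8, 0.1, 0.08, 181.7),
    ("zennet_imagenet1k_latency02ms_res224", 80.8, 0.2, 0.15, 357.4),
    ("zennet_imagenet1k_latency03ms_res224", 81.5, 0.3, 0.20, 517.0),
    ("zennet_imagenet1k_latency05ms_res224", 82.7, 0.5, 0.30, 798.7),
    ("zennet_imagenet1k_latency08ms_res224", 83.0, 0.8, 0.57, 1365.0),
    ("zennet_imagenet1k_latency12ms_res224", 83.6, 1.2, 0.85, 2051.0),
]
DEVICE_COLUMN = {"V100": 2, "T4": 3, "Pixel2": 4}


def best_model_under_budget(device: str, budget_ms: float) -> Optional[str]:
    """Return the most accurate ZenNet whose latency on `device` is <= budget_ms, or None."""
    col = DEVICE_COLUMN[device]
    candidates = [row for row in ZENNET_IMAGENET if row[col] <= budget_ms]
    if not candidates:
        return None
    # Highest accuracy wins; ties go to the faster model.
    return max(candidates, key=lambda row: (row[1], -row[col]))[0]


print(best_model_under_budget("V100", 0.6))  # zennet_imagenet1k_latency05ms_res224
```

The returned name can be passed directly as --arch to val.py.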

CIFAR-10/CIFAR-100 Models

Model                                Resolution  # Params  FLOPs  Top-1 Acc
zennet_cifar10_model_size05M_res32   32          0.5M      140M   96.2%
zennet_cifar10_model_size1M_res32    32          1.0M      162M   96.2%
zennet_cifar10_model_size2M_res32    32          2.0M      487M   97.5%
zennet_cifar100_model_size05M_res32  32          0.5M      140M   79.9%
zennet_cifar100_model_size1M_res32   32          1.0M      162M   80.1%
zennet_cifar100_model_size2M_res32   32          2.0M      487M   84.4%

Reproduce Paper Experiments

System Requirements

  • PyTorch >= 1.5, Python >= 3.7
  • By default, the ImageNet dataset is stored under ~/data/imagenet, and CIFAR-10/CIFAR-100 under ~/data/pytorch_cifar10 or ~/data/pytorch_cifar100
  • Pre-trained parameters are cached under ~/.cache/pytorch/checkpoints/zennet_pretrained

Evaluate pre-trained models on ImageNet and CIFAR-10/100

To evaluate the pre-trained model on ImageNet using GPU 0:

python val.py --fp16 --gpu 0 --arch ${zennet_model_name}

where ${zennet_model_name} should be replaced by a valid ZenNet model name. The complete list of model names can be found in the 'Pre-trained Models' section.

To evaluate the pre-trained model on CIFAR-10 or CIFAR-100 using GPU 0:

python val_cifar.py --dataset cifar10 --gpu 0 --arch ${zennet_model_name}
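When evaluating several CIFAR checkpoints in one go, the --dataset value has to match the dataset each model was trained on. A small driver sketch, assuming val_cifar.py also accepts --dataset cifar100 (only cifar10 is shown above) and that the dataset can be read from the model-name prefix; eval_command and run_all are hypothetical helpers, and run_all only prints the commands unless dry_run=False.

```python
import shlex
import subprocess
from typing import List

CIFAR_MODELS = [
    "zennet_cifar10_model_size05M_res32",
    "zennet_cifar10_model_size1M_res32",
    "zennet_cifar10_model_size2M_res32",
    "zennet_cifar100_model_size05M_res32",
    "zennet_cifar100_model_size1M_res32",
    "zennet_cifar100_model_size2M_res32",
]


def eval_command(arch: str, gpu: int = 0) -> List[str]:
    """Build the val_cifar.py command line, taking the dataset from the model name."""
    dataset = arch.split("_")[1]  # "cifar10" or "cifar100"
    if dataset not in ("cifar10", "cifar100"):
        raise ValueError(f"not a CIFAR ZenNet model: {arch}")
    return ["python", "val_cifar.py", "--dataset", dataset,
            "--gpu", str(gpu), "--arch", arch]


def run_all(gpu: int = 0, dry_run: bool = True) -> None:
    """Print (and optionally run) one evaluation per CIFAR model."""
    for arch in CIFAR_MODELS:
        cmd = eval_command(arch, gpu)
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stops at the first failing evaluation


if __name__ == "__main__":
    run_all()
```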

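To create a ZenNet in your Python code:

import torch
import ZenNet  # run from the repository root so the ZenNet package is importable

gpu = 0
arch = 'zennet_imagenet1k_flops400M_SE_res224'  # any model name from 'Pre-trained Models'
torch.cuda.set_device(gpu)
torch.backends.cudnn.benchmark = True
model = ZenNet.get_ZenNet(arch, pretrained=True)
model = model.cuda(gpu)
model = model.half()  # FP16 inference, matching the --fp16 evaluation above
model.eval()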

Searching on CIFAR-10/100

To search for CIFAR-10/100 models with fewer than 1M parameters using each of the zero-shot proxies, run:

scripts/Flops_NAS_cifar_params1M.sh
scripts/GradNorm_NAS_cifar_params1M.sh
scripts/NASWOT_NAS_cifar_params1M.sh
scripts/Params_NAS_cifar_params1M.sh
scripts/Random_NAS_cifar_params1M.sh
scripts/Syncflow_NAS_cifar_params1M.sh
scripts/TE_NAS_cifar_params1M.sh
scripts/Zen_NAS_cifar_params1M.sh

Searching on ImageNet

To search for ImageNet models under a latency budget on NVIDIA V100, from 0.1 ms/image to 1.2 ms/image at batch size 64 and FP16, run:

scripts/Zen_NAS_ImageNet_latency0.1ms.sh
scripts/Zen_NAS_ImageNet_latency0.2ms.sh
scripts/Zen_NAS_ImageNet_latency0.3ms.sh
scripts/Zen_NAS_ImageNet_latency0.5ms.sh
scripts/Zen_NAS_ImageNet_latency0.8ms.sh
scripts/Zen_NAS_ImageNet_latency1.2ms.sh

To search for ImageNet models under a FLOPs budget from 400M to 800M, run:

scripts/Zen_NAS_ImageNet_flops400M.sh
scripts/Zen_NAS_ImageNet_flops600M.sh
scripts/Zen_NAS_ImageNet_flops800M.sh

Customize Your Own Search Space and Zero-Shot Proxy

The masternet definition is stored in "Masternet.py". The masternet takes a structure string and parses it into a PyTorch nn.Module object. The structure string defines the layer structure, whose building blocks are implemented in the "PlainNet/*.py" files. For example, "PlainNet/SuperResK1KXK1.py" defines the SuperResK1K3K1 block, which consists of multiple layers of ResNet blocks. To define your own block, e.g. ABC_Block, first implement it in "PlainNet/ABC_Block.py". Then append the following lines at the end of "PlainNet/__init__.py" to register the new block definition:

from PlainNet import ABC_Block
_all_netblocks_dict_ = ABC_Block.register_netblocks_dict(_all_netblocks_dict_)

After this registration call, the PlainNet module can parse your customized block from a structure string.

The search space definitions are stored in SearchSpace/*.py. The important function is

gen_search_space(block_list, block_id)

block_list is a list of super-blocks parsed by the masternet, and block_id is the index of the block in block_list that will later be replaced by a mutated block. The function must return a list of mutated blocks.
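A minimal sketch of the gen_search_space contract, assuming a simplified, hypothetical BlockSpec record in place of the real PlainNet block objects. It proposes replacements for block_list[block_id] by scaling the output width (rounded to a multiple of 8) and adding or removing one sub-layer; these mutation choices are illustrative only, and the actual files in SearchSpace/*.py define their own.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class BlockSpec:
    """Hypothetical stand-in for one parsed super-block."""
    name: str
    in_channels: int
    out_channels: int
    stride: int
    sub_layers: int


# Hypothetical mutation choices.
WIDTH_SCALES = (0.5, 0.75, 1.25, 1.5)
DEPTH_DELTAS = (-1, 1)
CHANNEL_MULTIPLE = 8


def _round_channels(channels: float) -> int:
    """Round to the nearest multiple of CHANNEL_MULTIPLE, never below it."""
    return max(CHANNEL_MULTIPLE, int(round(channels / CHANNEL_MULTIPLE)) * CHANNEL_MULTIPLE)


def gen_search_space(block_list: List[BlockSpec], block_id: int) -> List[BlockSpec]:
    """Return mutated candidates that may replace block_list[block_id]."""
    block = block_list[block_id]
    candidates = []
    for scale in WIDTH_SCALES:  # narrower or wider output
        new_out = _round_channels(block.out_channels * scale)
        if new_out != block.out_channels:
            candidates.append(replace(block, out_channels=new_out))
    for delta in DEPTH_DELTAS:  # one sub-layer fewer or more
        new_depth = block.sub_layers + delta
        if new_depth >= 1:
            candidates.append(replace(block, sub_layers=new_depth))
    return candidates


blocks = [BlockSpec("SuperResK1K3K1", 32, 64, 2, 2)]
for cand in gen_search_space(blocks, 0):
    print(cand.out_channels, cand.sub_layers)
```

In this sketch, a width change also changes the input width expected by the next block, so the caller would need to update block_list[block_id + 1].in_channels before building the mutated network.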

The zero-shot proxies are implemented in "ZeroShotProxy/*.py", and the evolutionary algorithm in "evolution_search.py". "analyze_model.py" prints the FLOPs and model size of a given network, "benchmark_network_latency.py" measures network inference latency, "train_image_classification.py" implements standard SGD training, and "ts_train_image_classification.py" implements teacher-student distillation.
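For orientation, the loop below is a generic sketch of budget-constrained evolutionary search driven by a zero-shot proxy: mutate a randomly chosen member of the population, discard children that exceed the parameter budget, score the rest with the proxy, and keep only the top-scoring candidates. It is not a transcription of evolution_search.py. The architecture encoding (a tuple of stage widths), param_count and toy_proxy are hypothetical placeholders; in the real pipeline the score would come from a proxy in ZeroShotProxy/*.py evaluated on an untrained network.

```python
import random
from typing import Callable, List, Tuple

Arch = Tuple[int, ...]  # hypothetical encoding: output width of each stage


def param_count(arch: Arch, in_channels: int = 3, kernel: int = 3) -> int:
    """Parameters of a plain chain of kxk convolutions (no bias, no BN)."""
    total, c_in = 0, in_channels
    for c_out in arch:
        total += c_in * c_out * kernel * kernel
        c_in = c_out
    return total


def toy_proxy(arch: Arch) -> float:
    """Placeholder score; a real zero-shot proxy would inspect an untrained network."""
    return float(sum(arch)) + 2.0 * len(arch)


def mutate(arch: Arch, rng: random.Random) -> Arch:
    """Rescale the width of one randomly chosen stage (minimum width 8)."""
    i = rng.randrange(len(arch))
    new_width = max(8, int(arch[i] * rng.choice((0.5, 0.75, 1.25, 1.5))))
    return arch[:i] + (new_width,) + arch[i + 1:]


def evolution_search(init: Arch, budget: int, proxy: Callable[[Arch], float],
                     iterations: int = 500, population_size: int = 16,
                     seed: int = 0) -> Arch:
    """Budget-constrained evolution that keeps the top-scoring candidates."""
    if param_count(init) >= budget:
        raise ValueError("the initial architecture already exceeds the budget")
    rng = random.Random(seed)
    population: List[Tuple[float, Arch]] = [(proxy(init), init)]
    for _ in range(iterations):
        _, parent = rng.choice(population)
        child = mutate(parent, rng)
        if param_count(child) >= budget:
            continue  # reject children that violate the parameter budget
        population.append((proxy(child), child))
        population.sort(key=lambda item: item[0], reverse=True)
        del population[population_size:]  # drop the lowest-scoring candidates
    return population[0][1]


if __name__ == "__main__":
    best = evolution_search(init=(16, 32, 64), budget=100_000, proxy=toy_proxy)
    print(best, param_count(best))
```

Because toy_proxy simply rewards width, this search drifts toward the budget limit; with a real proxy the ranking, and therefore the result, would differ.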

FAQ

Q: Why is searching with latency constraints so slow?
A: Most of the time is spent benchmarking network latency. In our paper we use a latency predictor, which is not released.


How to Cite This Work

Ming Lin, Pichao Wang, Zhenhong Sun, Hesen Chen, Xiuyu Sun, Qi Qian, Hao Li, Rong Jin. Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition. 2021 IEEE/CVF International Conference on Computer Vision (ICCV 2021).

@inproceedings{ming_zennas_iccv2021,
  author    = {Ming Lin and Pichao Wang and Zhenhong Sun and Hesen Chen and Xiuyu Sun and Qi Qian and Hao Li and Rong Jin},
  title     = {Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition},
  booktitle = {2021 IEEE/CVF International Conference on Computer Vision, {ICCV} 2021},  
  year      = {2021},
}

Open Source

A few files in this repository are modified from the following open-source implementations:

https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
https://github.com/VITA-Group/TENAS
https://github.com/SamsungLabs/zero-cost-nas
https://github.com/BayesWatch/nas-without-training
https://github.com/rwightman/gen-efficientnet-pytorch
https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html

Copyright

Copyright (C) 2010-2021 Alibaba Group Holding Limited.

Comments
  • Multi-GPU training

    Training for 480 epochs on a single GPU takes a lot of time, so I wanted to check whether multi-GPU training is possible with the current code. I tried --dist_mode with mpi and auto, but neither is supported because global_utils.AutoGPU() is not defined. However, --dist_mode=horovod appears to be supported. Can it be used for multi-GPU training? Could you share a sample command for multi-GPU training?

    By the way, it took around 2.83 hours to complete one epoch on my GTX 1080 Ti. Does this sound reasonable, or is it too high? On my machine (4 x GTX 1080) I typically see around 15 minutes per epoch for EfficientNet-ES with Ross Wightman's timm repo.

    opened by soyebn
  • Deeper and wider network has higher accuracy?

    Figure 2 in the paper shows that deeper and wider networks have higher Zen-Scores, and Zen-Score correlates positively with model accuracy. So deeper and wider networks have higher accuracy, which is a well-known principle. What, then, is the meaning of Zen-Score?

    opened by buaabai
  • Hi MingLin

    Your proposed Zen-NAS is a very efficient way to search for neural network structures. I read your paper and GitHub code carefully and ran my own searches with your code. One thing I found is that the searched networks almost always repeat blocks more times in the deeper stages, while the blocks in the first few stages are used only once. For example, I used your code to search an MNas-like structure (with the search space changed to follow MNas). The optimal MNas 0.35 structure is:

    SuperConvK3BNRELU(3,16,2,1)SuperResMnasV1K3(16,8,1,16,1)SuperResMnasV3K3(8,8,2,8,3)SuperResMnasV3K5(8,16,2,8,3)SuperResMnasV6K5(16,32,2,16,3)SuperResMnasV6K3(32,32,1,32,2)SuperResMnasV6K5(32,64,2,32,4)SuperResMnasV6K3(64,112,1,64,1)SuperConvK1BNRELU(112,1280,1,1)

    but the structure I found with your code is:

    SuperConvK3BNRELU(3,8,2,1)SuperResMnasV1K3(8,8,1,8,1)SuperResMnasV3K5(8,16,2,8,1)SuperResMnasV3K5(16,24,2,8,1)SuperResMnasV3K5(24,64,2,40,1)SuperResMnasV3K5(64,24,1,48,1)SuperResMnasV3K5(24,64,2,176,4)SuperResMnasV3K5(64,48,1,256,5)SuperConvK1BNRELU(48,2048,1,1)

    So the searched structure does not compare well with the original one, and it still shows the problem mentioned above: the shallow stages have only one block each, and only the deeper stages repeat blocks. I also tested your code with the Flops400M, 600M and 900M settings and found the same problem. Why is this?

    opened by billbig
  • Is it possible for you to release the retraining log files?

    Hi @MingLin-home, in the paper and the official code, retraining the searched models with the feature loss costs too many resources, and the 1.2 ms latency model is very difficult to train fully end to end. Could the training logs for this series of models be released?

    opened by aptsunny
  • NAS-Bench-201

    Hello,

    As I remember from the paper, your method works on vanilla CNNs. However, Algorithm 1 only mentions that residual connections are removed.

    This confuses me a little: I do not know whether your method can be applied to any CNN after removing its residual connections, or only to vanilla CNNs.

    Can I use your code on benchmarks such as NAS-Bench-201?

    opened by farhad-dalirani
  • Some simple questions

    Dear researcher, hello.
    I ran into a problem I couldn't solve when testing on CIFAR-10 with:

    python val_cifar.py --dataset cifar10 --gpu 0 --arch zennet_cifar10_model_size05M_res32

    RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

    I am using torch 1.7.1, torchvision 0.8.2 and cudatoolkit 10.2.89. Is this a problem with my environment, or do I need to change the code? Thank you very much.

    opened by fly2tortoise
  • Same search space used in search_space_IDW_fixfc.py as well as in search_space_XXBL.py

    First of all, thanks a lot for releasing the code for such nice work. I have a few questions:

    1. The seach_space_block_type_list_list in both files, search_space_IDW_fixfc.py and search_space_XXBL.py, is the same. Is this intended?

    2. The scripts Zen_NAS_ImageNet_flops400M.sh and Zen_NAS_ImageNet_latency0.2ms use the same search space, but the paper says Zen_NAS_ImageNet_flops400M.sh uses MB blocks and Zen_NAS_ImageNet_latency0.2ms uses bottleneck (botn) blocks like ResNet-50.

    Thanks again.

    opened by soyebn
  • A training question about GENet

    Hi everyone. I'm currently trying to reproduce your previous work, "Neural Architecture Design for GPU-Efficient Networks". Once the model structure is fixed, for example "GENet_large" or "GENet_small", can we use the training script (train_image_classification.py) in this repo to train the model and get results consistent with the paper?

    opened by FisherWY
  • How to do `entropy_forward` for a CSP network?

    My block looks like this:

    BottleneckCSP(
      (cv1): Conv(
        (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Mish()
      )
      (cv2): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (cv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (cv4): Conv(
        (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Mish()
      )
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): Mish()
      (m): Sequential(
        (0): Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): Mish()
          )
          (cv2): Conv(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): Mish()
          )
        )
        (1): Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): Mish()
          )
          (cv2): Conv(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): Mish()
          )
        )
      )
    )

    Its forward method is:

    def forward(self, x):
        d = self.m(self.cv1(x))
        y1 = self.cv3(d)
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))

    opened by 1chimaruGin
  • Use of NTK Condition Number in Combination with ZEN Score.

    Hello MingLin,

    I found your NAS approach very interesting.

    Did you ever try to combine the Zen-Score with the NTK score, as the TE-NAS paper did? The TE-NAS paper suggests that combining the NTK score with an expressivity measure improves overall performance.

    Thanks, SXK

    opened by SXKJames
  • Does ZenScore Support MobileNet-like Search Space?

    Thank you for your great work.

    I have tried Zen-Score and got a high Kendall tau on plain networks.

    Can Zen-Score be applied to a MobileNet-like search space?

    Looking forward to your reply.

    opened by pprp
  • GENet: how to use soft labels

    I am following the paper "Neural Architecture Design for GPU-Efficient Networks", and on page 8 it says: "We also use ResNet-152 as teacher network to get soft-labels. The soft-label loss and the true-label loss are weighted 1:1." How are the soft labels used to teach GENet? Is it like knowledge distillation? Is there any code for this in this repo?

    opened by bigcash
Latest release: iccv-2021-v1.1
Owner: A vision team from Alibaba