Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition


Zen-NAS is a lightning-fast, training-free Neural Architecture Search (NAS) algorithm for automatically designing deep neural networks with high prediction accuracy and high inference speed on GPUs and mobile devices.

This repository contains pre-trained models, a mini framework for zero-shot NAS, and scripts to reproduce our results. You can also customize your own search space and develop a new zero-shot NAS proxy using our pipeline. Contributions are welcome.

Our paper is available on arXiv and will appear in ICCV 2021. The BibTeX entry is given under "How to Cite This Work".

How Fast

Searching for 12 hours on a single GPU, ZenNAS designs networks whose ImageNet top-1 accuracy is comparable to EfficientNet-B5 (~83.6%) while running inference 4.9x faster on NVIDIA V100, 10x faster on NVIDIA T4, and 1.6x faster on Google Pixel2.

Inference Speed

Comparison with Other Zero-Shot NAS Proxies on CIFAR-10/100

We use the ResNet-like search space and search for models within a parameter budget of 1M. All models are searched with the same evolutionary strategy and trained on CIFAR-10/100 for 1440 epochs with auto-augmentation, cosine learning rate decay, and weight decay 5e-4. We report the top-1 accuracies in the following table:

| proxy | CIFAR-10 | CIFAR-100 |
|---|---|---|
| Zen-NAS | 96.2% | 80.1% |
| FLOPs | 93.1% | 64.7% |
| grad-norm | 92.8% | 65.4% |
| synflow | 95.1% | 75.9% |
| TE-NAS | 96.1% | 77.2% |
| NASWOT | 96.0% | 77.5% |
| Random | 93.5% | 71.1% |

Please check our paper for more details.
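
For readers unfamiliar with the cosine learning rate decay mentioned in the training setup, a minimal sketch of such a schedule is shown here. It is an illustration only, not the repository's training code: the base learning rate `base_lr=0.1` and final rate `min_lr=0.0` are assumptions chosen for the example, and only the 1440-epoch length comes from the setup described above.

```python
import math


def cosine_lr(epoch, total_epochs=1440, base_lr=0.1, min_lr=0.0):
    """Learning rate at `epoch` under half-cosine decay from base_lr to min_lr.

    base_lr and min_lr are assumed values for illustration only.
    """
    # Clamp progress to [0, 1] so out-of-range epochs do not wrap around.
    progress = min(max(epoch / total_epochs, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for epoch in (0, 360, 720, 1080, 1440):
        print(epoch, round(cosine_lr(epoch), 5))
```

The rate starts at `base_lr`, reaches the midpoint between `base_lr` and `min_lr` halfway through training, and ends at `min_lr`.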

Pre-trained Models

We provide pre-trained models on ImageNet and CIFAR-10/CIFAR-100.

ImageNet Models

| model | resolution | # params | FLOPs | Top-1 Acc | V100 (ms) | T4 (ms) | Pixel2 (ms) |
|---|---|---|---|---|---|---|---|
| zennet_imagenet1k_flops400M_SE_res224 | 224 | 5.7M | 410M | 78.0% | 0.25 | 0.39 | 87.9 |
| zennet_imagenet1k_flops600M_SE_res224 | 224 | 7.1M | 611M | 79.1% | 0.36 | 0.52 | 128.6 |
| zennet_imagenet1k_flops900M_SE_res224 | 224 | 19.4M | 934M | 80.8% | 0.55 | 0.55 | 215.7 |
| zennet_imagenet1k_latency01ms_res224 | 224 | 30.1M | 1.7B | 77.8% | 0.1 | 0.08 | 181.7 |
| zennet_imagenet1k_latency02ms_res224 | 224 | 49.7M | 3.4B | 80.8% | 0.2 | 0.15 | 357.4 |
| zennet_imagenet1k_latency03ms_res224 | 224 | 85.4M | 4.8B | 81.5% | 0.3 | 0.20 | 517.0 |
| zennet_imagenet1k_latency05ms_res224 | 224 | 118M | 8.3B | 82.7% | 0.5 | 0.30 | 798.7 |
| zennet_imagenet1k_latency08ms_res224 | 224 | 183M | 13.9B | 83.0% | 0.8 | 0.57 | 1365 |
| zennet_imagenet1k_latency12ms_res224 | 224 | 180M | 22.0B | 83.6% | 1.2 | 0.85 | 2051 |
| EfficientNet-B3 | 300 | 12.0M | 1.8B | 81.1% | 1.12 | 1.86 | 569.3 |
| EfficientNet-B5 | 456 | 30.0M | 9.9B | 83.3% | 4.5 | 7.0 | 2580 |
| EfficientNet-B6 | 528 | 43M | 19.0B | 84.0% | 7.64 | 12.3 | 4288 |

  • 'V100' is the inference latency on NVIDIA V100 in milliseconds, benchmarked at batch size 64, float16.
  • 'T4' is the inference latency on NVIDIA T4 in milliseconds, benchmarked at batch size 64, TensorRT INT8.
  • 'Pixel2' is the inference latency on Google Pixel2 in milliseconds, benchmarked on a single image.
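
As a rough illustration of how a per-image latency figure can be obtained, the sketch below times a batch callable and divides by the batch size. It assumes that the reported latency is the batch time divided by the number of images (the "ms/image" unit used for the search budgets suggests this). It is not the repository's benchmark_network_latency.py; `measure_latency_ms`, `warmup`, and `repeats` are hypothetical names. For a real GPU measurement, the callable would also need to synchronize the device before the clock is read.

```python
import time


def measure_latency_ms(run_batch, batch_size, warmup=3, repeats=10):
    """Average wall-clock milliseconds per image for one call of `run_batch()`.

    `run_batch` is any zero-argument callable that processes `batch_size` images.
    """
    for _ in range(warmup):
        run_batch()  # warm-up runs are executed but not timed
    start = time.perf_counter()
    for _ in range(repeats):
        run_batch()
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000.0 / (repeats * batch_size)


if __name__ == "__main__":
    def fake_batch():
        time.sleep(0.001)  # stand-in for a forward pass over 64 images

    print(f"{measure_latency_ms(fake_batch, batch_size=64):.4f} ms/image")
```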

CIFAR-10/CIFAR-100 Models

model resolution # params FLOPs Top-1 Acc
zennet_cifar10_model_size05M_res32 32 0.5M 140M 96.2%
zennet_cifar10_model_size1M_res32 32 1.0M 162M 96.2%
zennet_cifar10_model_size2M_res32 32 2.0M 487M 97.5%
zennet_cifar100_model_size05M_res32 32 0.5M 140M 79.9%
zennet_cifar100_model_size1M_res32 32 1.0M 162M 80.1%
zennet_cifar100_model_size2M_res32 32 2.0M 487M 84.4%

Reproduce Paper Experiments

System Requirements

  • PyTorch >= 1.5, Python >= 3.7
  • By default, the ImageNet dataset is stored under ~/data/imagenet; CIFAR-10 and CIFAR-100 are stored under ~/data/pytorch_cifar10 and ~/data/pytorch_cifar100, respectively
  • Pre-trained parameters are cached under ~/.cache/pytorch/checkpoints/zennet_pretrained
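
Before running the scripts, it can help to check that these default locations exist. The helper below is a small convenience sketch and is not part of the repository; `DEFAULT_LOCATIONS`, `resolve_location`, and `missing_locations` are hypothetical names, and only the paths themselves come from the list above.

```python
from pathlib import Path

# Default locations listed in the system requirements above.
DEFAULT_LOCATIONS = {
    "imagenet": "~/data/imagenet",
    "cifar10": "~/data/pytorch_cifar10",
    "cifar100": "~/data/pytorch_cifar100",
    "pretrained": "~/.cache/pytorch/checkpoints/zennet_pretrained",
}


def resolve_location(name):
    """Expand '~' in the default path registered under `name`."""
    return Path(DEFAULT_LOCATIONS[name]).expanduser()


def missing_locations():
    """Names of default locations that do not exist on this machine."""
    return [name for name in DEFAULT_LOCATIONS if not resolve_location(name).exists()]


if __name__ == "__main__":
    for name in missing_locations():
        print(f"missing: {name} -> {resolve_location(name)}")
```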

Evaluate pre-trained models on ImageNet and CIFAR-10/100

To evaluate the pre-trained model on ImageNet using GPU 0:

python val.py --fp16 --gpu 0 --arch ${zennet_model_name}

where ${zennet_model_name} should be replaced by a valid ZenNet model name. The complete list of model names can be found in the 'Pre-trained Models' section.

To evaluate the pre-trained model on CIFAR-10 or CIFAR-100 using GPU 0:

python val_cifar.py --dataset cifar10 --gpu 0 --arch ${zennet_model_name}

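To create a ZenNet in your Python code:

import torch
import ZenNet  # the ZenNet package from this repository

gpu = 0
arch = 'zennet_imagenet1k_flops400M_SE_res224'  # any name from 'Pre-trained Models'
model = ZenNet.get_ZenNet(arch, pretrained=True)
torch.cuda.set_device(gpu)
torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest kernels
model = model.cuda(gpu)
model = model.half()  # FP16 inference
model.eval()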

Searching on CIFAR-10/100

Searching for CIFAR-10/100 models with a parameter budget below 1M, using different zero-shot proxies:

scripts/Flops_NAS_cifar_params1M.sh
scripts/GradNorm_NAS_cifar_params1M.sh
scripts/NASWOT_NAS_cifar_params1M.sh
scripts/Params_NAS_cifar_params1M.sh
scripts/Random_NAS_cifar_params1M.sh
scripts/Syncflow_NAS_cifar_params1M.sh
scripts/TE_NAS_cifar_params1M.sh
scripts/Zen_NAS_cifar_params1M.sh

Searching on ImageNet

Searching for ImageNet models, with latency budget on NVIDIA V100 from 0.1 ms/image to 1.2 ms/image at batch size 64 FP16:

scripts/Zen_NAS_ImageNet_latency0.1ms.sh
scripts/Zen_NAS_ImageNet_latency0.2ms.sh
scripts/Zen_NAS_ImageNet_latency0.3ms.sh
scripts/Zen_NAS_ImageNet_latency0.5ms.sh
scripts/Zen_NAS_ImageNet_latency0.8ms.sh
scripts/Zen_NAS_ImageNet_latency1.2ms.sh

Searching for ImageNet models, with FLOPs budget from 400M to 800M:

scripts/Zen_NAS_ImageNet_flops400M.sh
scripts/Zen_NAS_ImageNet_flops600M.sh
scripts/Zen_NAS_ImageNet_flops800M.sh

Customize Your Own Search Space and Zero-Shot Proxy

The masternet definition is stored in "Masternet.py". The masternet takes in a structure string and parses it into a PyTorch nn.Module object. The structure string defines the layer structure, which is implemented in the "PlainNet/*.py" files. For example, "PlainNet/SuperResK1KXK1.py" defines the SuperResK1K3K1 block, which consists of multiple layers of ResNet blocks. To define your own block, e.g. ABC_Block, first implement "PlainNet/ABC_Block.py". Then, at the end of "PlainNet/__init__.py", append the following lines to register the new block definition:

from PlainNet import ABC_Block
_all_netblocks_dict_ = ABC_Block.register_netblocks_dict(_all_netblocks_dict_)

After the above registration call, the PlainNet module is able to parse your customized block from structure string.
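
The registration pattern can be sketched in plain Python as follows. This is a toy illustration under stated assumptions, not the repository's implementation: the `ABC_Block` class here only stores channel counts instead of being an nn.Module, and its `create_from_str` method, the `parse_structure` helper, and the `Name(arg,...)` structure-string format are assumptions made for the example. The actual block interface is defined in the "PlainNet/*.py" files and may differ.

```python
import re


class ABC_Block:
    """Toy block holding channel counts (stand-in for a real nn.Module block)."""

    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels

    @classmethod
    def create_from_str(cls, arg_str):
        """Build a block from the comma-separated arguments inside the parentheses."""
        in_channels, out_channels = (int(v) for v in arg_str.split(","))
        return cls(in_channels, out_channels)


def register_netblocks_dict(netblocks_dict):
    """Add the blocks defined in this file to the shared registry and return it."""
    netblocks_dict.update({"ABC_Block": ABC_Block})
    return netblocks_dict


def parse_structure(structure_str, netblocks_dict):
    """Turn a string such as 'Name(a,b)Name(c,d)' into a list of block objects."""
    blocks = []
    for name, arg_str in re.findall(r"(\w+)\(([^)]*)\)", structure_str):
        if name not in netblocks_dict:
            raise KeyError(f"unregistered block type: {name}")
        blocks.append(netblocks_dict[name].create_from_str(arg_str))
    return blocks


if __name__ == "__main__":
    registry = register_netblocks_dict({})
    net = parse_structure("ABC_Block(3,16)ABC_Block(16,32)", registry)
    print([(b.in_channels, b.out_channels) for b in net])
```

Keeping the registry as a plain dictionary keyed by block name means the parser never has to know about individual block types; a new block only needs to add itself to the dictionary.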

The search space definitions are stored in SearchSpace/*.py. The important function is

gen_search_space(block_list, block_id)

block_list is a list of super-blocks parsed by the masternet. block_id is the index of the block in block_list that will later be replaced by a mutated block. This function must return a list of mutated blocks.
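
A minimal sketch of a function with this contract is shown below. It is an assumption-laden illustration, not one of the search spaces in SearchSpace/*.py: `ToyBlock`, the width multipliers, and the depth deltas are hypothetical, and real super-blocks carry more attributes than the two used here.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyBlock:
    """Hypothetical stand-in for a super-block: output width and layer count."""
    out_channels: int
    sub_layers: int


def gen_search_space(block_list, block_id):
    """Return candidate replacements for block_list[block_id].

    The input list is left untouched; each candidate is a modified copy.
    """
    block = block_list[block_id]
    candidates = []
    for ratio in (0.5, 1.5, 2.0):  # hypothetical width multipliers
        # Round the new width to a multiple of 8, with a floor of 8 channels.
        channels = max(8, int(round(block.out_channels * ratio / 8)) * 8)
        candidates.append(replace(block, out_channels=channels))
    for delta in (-1, 1):  # hypothetical depth changes
        depth = block.sub_layers + delta
        if depth >= 1:  # a block needs at least one layer
            candidates.append(replace(block, sub_layers=depth))
    return candidates


if __name__ == "__main__":
    blocks = [ToyBlock(16, 1), ToyBlock(32, 2)]
    for candidate in gen_search_space(blocks, block_id=1):
        print(candidate)
```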

The zero-shot proxies are implemented in "ZeroShotProxy/*.py". The evolutionary algorithm is implemented in "evolution_search.py". "analyze_model.py" prints the FLOPs and model size of the given network. "benchmark_network_latency.py" measures the network inference latency. "train_image_classification.py" implements SGD gradient training and "ts_train_image_classification.py" implements teacher-student distillation.
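
To show how a zero-shot proxy and an evolutionary loop fit together, here is a self-contained toy sketch. It is not the code in "evolution_search.py": architectures are reduced to tuples of (width, depth) pairs, `toy_param_count` is a made-up size measure, and `toy_proxy_score` is a hypothetical training-free score that is not the Zen-score. The loop only illustrates the general pattern of mutating a population member, rejecting candidates that exceed the budget, and discarding the lowest-scoring member.

```python
import random


def toy_param_count(arch):
    """Made-up size measure: sum of width^2 * depth over all blocks."""
    return sum(w * w * d for w, d in arch)


def toy_proxy_score(arch):
    """Hypothetical training-free score; NOT the Zen-score from the paper."""
    return sum(d * (w ** 0.5) for w, d in arch)


def mutate(arch, rng):
    """Change the width or the depth of one randomly chosen block."""
    arch = list(arch)
    i = rng.randrange(len(arch))
    w, d = arch[i]
    if rng.random() < 0.5:
        w = max(8, w + rng.choice((-8, 8)))
    else:
        d = max(1, d + rng.choice((-1, 1)))
    arch[i] = (w, d)
    return tuple(arch)


def evolution_search(init_arch, budget, iterations=500, population_size=16, seed=0):
    """Return the highest-scoring architecture found within the size budget."""
    rng = random.Random(seed)
    population = [tuple(init_arch)]
    for _ in range(iterations):
        child = mutate(rng.choice(population), rng)
        if toy_param_count(child) > budget:
            continue  # reject candidates that exceed the budget
        population.append(child)
        if len(population) > population_size:
            # Drop the lowest-scoring member to keep the population bounded.
            population.remove(min(population, key=toy_proxy_score))
    return max(population, key=toy_proxy_score)


if __name__ == "__main__":
    best = evolution_search(((16, 1), (32, 1)), budget=20000)
    print(best, toy_param_count(best), round(toy_proxy_score(best), 2))
```

Because no candidate is ever trained, each iteration costs only one proxy evaluation and one budget check, which is what makes this style of search fast.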

FAQ

Q: Why is searching with latency constraints so slow?
A: Most of the time is spent benchmarking the network latency. In our paper we use a latency predictor, which has not been released.

How to Cite This Work

Ming Lin, Pichao Wang, Zhenhong Sun, Hesen Chen, Xiuyu Sun, Qi Qian, Hao Li, Rong Jin. Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition. 2021 IEEE/CVF International Conference on Computer Vision (ICCV 2021).

@inproceedings{ming_zennas_iccv2021,
  author    = {Ming Lin and Pichao Wang and Zhenhong Sun and Hesen Chen and Xiuyu Sun and Qi Qian and Hao Li and Rong Jin},
  title     = {Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition},
  booktitle = {2021 IEEE/CVF International Conference on Computer Vision, {ICCV} 2021},  
  year      = {2021},
}

Open Source

A few files in this repository are modified from the following open-source implementations:

https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
https://github.com/VITA-Group/TENAS
https://github.com/SamsungLabs/zero-cost-nas
https://github.com/BayesWatch/nas-without-training
https://github.com/rwightman/gen-efficientnet-pytorch
https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html

Copyright

Copyright (C) 2010-2021 Alibaba Group Holding Limited.
