NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size

Xuanyi Dong, Lu Liu, Katarzyna Musial, Bogdan Gabrys

in IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), 2021

Abstract: Neural architecture search (NAS) has attracted a lot of attention and has been illustrated to bring tangible benefits in a large number of applications in the past few years. Network topology and network size have been regarded as two of the most important aspects for the performance of deep learning models and the community has spawned lots of searching algorithms for both of those aspects of the neural architectures. However, the performance gain from these searching algorithms is achieved under different search spaces and training setups. This makes the overall performance of the algorithms incomparable and the improvement from a sub-module of the searching model unclear. In this paper, we propose NATS-Bench, a unified benchmark on searching for both topology and size, for (almost) any up-to-date NAS algorithm. NATS-Bench includes the search space of 15,625 neural cell candidates for architecture topology and 32,768 for architecture size on three datasets. We analyze the validity of our benchmark in terms of various criteria and performance comparison of all candidates in the search space. We also show the versatility of NATS-Bench by benchmarking 13 recent state-of-the-art NAS algorithms on it. All logs and diagnostic information trained using the same setup for each candidate are provided. This facilitates a much larger community of researchers to focus on developing better NAS algorithms in a more comparable and computationally effective environment.
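The two search-space sizes follow from simple counting. Here is a minimal sketch. It assumes (as in NAS-Bench-201) that a topology cell has 4 nodes joined by 6 edges, and that each edge carries one of the 5 operations appearing in the architecture strings later in this README. It also assumes that the size space picks one of 8 widths (8 to 64 in steps of 8) for each of 5 layers. That gives 5^6 = 15,625 and 8^5 = 32,768. The colon-separated size-string format is also an assumption.

```python
import itertools

# Operation names used in NATS-Bench topology strings (see the discovered architectures below).
TSS_OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
# Assumed candidate channel widths for each of the five layers of the size search space.
SSS_CHANNELS = [8, 16, 24, 32, 40, 48, 56, 64]


def tss_strings():
    """Yield one string per cell: node j (1..3) receives one op from every earlier node i < j."""
    for ops in itertools.product(TSS_OPS, repeat=6):  # 1 + 2 + 3 = 6 edges
        nodes, k = [], 0
        for j in range(1, 4):
            nodes.append("|" + "|".join(f"{ops[k + i]}~{i}" for i in range(j)) + "|")
            k += j
        yield "+".join(nodes)


def sss_strings():
    """Yield one colon-separated width string per candidate, e.g. '8:16:64:32:8'."""
    for widths in itertools.product(SSS_CHANNELS, repeat=5):
        yield ":".join(str(w) for w in widths)


print(len(TSS_OPS) ** 6, sum(1 for _ in tss_strings()))       # 15625 15625
print(len(SSS_CHANNELS) ** 5, sum(1 for _ in sss_strings()))  # 32768 32768
```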

You can install the NATS-Bench library with pip install nats_bench, or install it from source with python setup.py install.

If you want to re-create NATS-Bench from scratch or reproduce the benchmarked results, please use AutoDL-Projects and follow these instructions.

If you have questions, please ask here or email me :)

This figure shows the main differences between NATS-Bench, NAS-Bench-101, and NAS-Bench-201. The topology search space ($\mathcal{S}_t$) in NATS-Bench is the same as that of NAS-Bench-201. We upgrade it with results from more runs of each architecture candidate, and the benchmarked NAS algorithms use better hyperparameters.

Preparation and Download

Step-1: download the raw vision datasets (you can skip this step if you neither use weight-sharing NAS nor re-create NATS-Bench).

In NATS-Bench, we (create and) use three image datasets: CIFAR-10, CIFAR-100, and ImageNet16-120. For more details, please see Sec-3.2 of the NATS-Bench paper. All three datasets can be downloaded from Google Drive. To create the ImageNet16-120 PyTorch dataset, use the ImageNet16 class in AutoDL-Projects/lib/datasets/ImageNet16 as follows:

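from xautodl.datasets.DownsampledImageNet import ImageNet16  # module path in the pip-installed xautodl package
# root: folder with the ImageNet16 files; *_transform: torchvision transforms; 120: use only the first 120 classes
train_data = ImageNet16(root, True, train_transform, 120)
test_data = ImageNet16(root, False, test_transform, 120)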

Step-2: download benchmark files of NATS-Bench.

The latest benchmark files of NATS-Bench can be downloaded from Google Drive. After downloading NATS-[tss/sss]-[version]-[md5sum]-simple.tar, uncompress it with tar xvf [file_name]. We highly recommend putting the downloaded benchmark file (NATS-sss-v1_0-50262.pickle.pbz2 / NATS-tss-v1_0-3ffb9.pickle.pbz2) or the uncompressed archive (NATS-sss-v1_0-50262-simple / NATS-tss-v1_0-3ffb9-simple) into $TORCH_HOME. Our API then finds these benchmark files automatically. Otherwise, you need to specify the file path manually when creating the benchmark instance.
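The $TORCH_HOME lookup can be pictured with the hypothetical helper below. It is not part of nats_bench. It assumes that $TORCH_HOME falls back to ~/.torch when unset, and that the newest version is the lexicographically last matching name. It only mirrors the naming scheme described here: a -simple directory for fast_mode=True and a .pickle.pbz2 file for fast_mode=False.

```python
import glob
import os


def torch_home():
    """Assumed lookup order: $TORCH_HOME, otherwise ~/.torch."""
    return os.environ.get("TORCH_HOME", os.path.join(os.path.expanduser("~"), ".torch"))


def find_benchmark(search_space, fast_mode, root=None):
    """Hypothetical helper: return the last matching NATS-{tss,sss} benchmark path under root, or None."""
    if search_space not in ("tss", "sss"):
        raise ValueError(f"unknown search space: {search_space}")
    root = root or torch_home()
    if fast_mode:
        # fast_mode=True reads the uncompressed archive directory, e.g. NATS-tss-v1_0-3ffb9-simple
        pattern = os.path.join(root, f"NATS-{search_space}-v*-simple")
        candidates = [p for p in glob.glob(pattern) if os.path.isdir(p)]
    else:
        # fast_mode=False reads the single compressed benchmark file, e.g. NATS-tss-v1_0-3ffb9.pickle.pbz2
        pattern = os.path.join(root, f"NATS-{search_space}-v*.pickle.pbz2")
        candidates = [p for p in glob.glob(pattern) if os.path.isfile(p)]
    return sorted(candidates)[-1] if candidates else None
```

If the helper finds nothing, you can pass an explicit path as the first argument of create(...) instead of None. That is the manual alternative mentioned above.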

The history of benchmark files is listed below, where tss denotes the topology search space and sss denotes the size search space. The benchmark file is used when creating the NATS-Bench instance with fast_mode=False. The archive is used when fast_mode=True. It is a directory containing 15,625 files for tss or 32,768 files for sss, and each file contains all the information for one architecture candidate. The full archive is similar to the archive, but each file in it also contains the trained weights. Since the full archive is very large, we use split -b 30G file_name file_name to split it into multiple 30G chunks. To merge the chunks back into the original full archive, use cat file_name* > file_name.

| Date | benchmark file (tss) | archive (tss) | full archive (tss) | benchmark file (sss) | archive (sss) | full archive (sss) |
|---|---|---|---|---|---|---|
| 2020.08.31 | NATS-tss-v1_0-3ffb9.pickle.pbz2 | NATS-tss-v1_0-3ffb9-simple.tar | NATS-tss-v1_0-3ffb9-full | NATS-sss-v1_0-50262.pickle.pbz2 | NATS-sss-v1_0-50262-simple.tar | NATS-sss-v1_0-50262-full |
| 2021.04.22 (Baidu-Pan) | NATS-tss-v1_0-3ffb9.pickle.pbz2 (code: 8duj) | NATS-tss-v1_0-3ffb9-simple.tar (code: tu1e) | NATS-tss-v1_0-3ffb9-full (code: ssub) | NATS-sss-v1_0-50262.pickle.pbz2 (code: za2h) | NATS-sss-v1_0-50262-simple.tar (code: e4t9) | NATS-sss-v1_0-50262-full (code: htif) |

These benchmark files (without pretrained weights) can also be downloaded from Dropbox, OneDrive or Baidu-Pan (extract code: h6pm).

For the full checkpoints in NATS-*ss-*-full, we split each file into multiple parts (NATS-*ss-*-full.tara*) because they are too large to upload. Each part is about 30GB. Since Baidu-Pan restricts the maximum size of each file, we further split each NATS-*ss-*-full.tara* into NATS-*ss-*-full.tara*-aa and NATS-*ss-*-full.tara*-ab. All splits are created with the split command.
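The split and cat commands can be reproduced in Python, for example on a system without GNU coreutils. Here is a minimal sketch with hypothetical helper names. It follows split's default two-letter suffixes (aa, ab, ...), which is why the parts are named NATS-*ss-*-full.tara*. The option -b 30G corresponds to chunk_bytes=30 * 1024**3.

```python
import glob
import itertools
import os
import shutil
import string


def split_suffixes():
    """Yield 'aa', 'ab', ..., 'zz': the default two-letter suffixes of `split` (676 parts at most)."""
    for first, second in itertools.product(string.ascii_lowercase, repeat=2):
        yield first + second


def split_file(path, prefix, chunk_bytes, buf_bytes=1 << 20):
    """Like `split -b chunk_bytes path prefix`: stream path into prefix+'aa', prefix+'ab', ..."""
    parts, suffixes = [], split_suffixes()
    with open(path, "rb") as src:
        while True:
            data = src.read(min(buf_bytes, chunk_bytes))
            if not data:
                break  # end of input, no empty trailing part
            part = prefix + next(suffixes)
            with open(part, "wb") as dst:
                written = 0
                while data:
                    dst.write(data)
                    written += len(data)
                    if written >= chunk_bytes:
                        break  # this part is full
                    data = src.read(min(buf_bytes, chunk_bytes - written))
            parts.append(part)
    return parts


def merge_files(prefix, out_path):
    """Like `cat prefix* > out_path`: concatenate the parts in sorted (aa, ab, ...) order."""
    out_abs = os.path.abspath(out_path)
    matches = sorted(glob.glob(glob.escape(prefix) + "*"))
    parts = [p for p in matches if os.path.abspath(p) != out_abs]  # never read the output file
    with open(out_path, "wb") as dst:
        for part in parts:
            with open(part, "rb") as src:
                shutil.copyfileobj(src, dst)
    return parts
```

For example, merge_files("NATS-tss-v1_0-3ffb9-full.tar", "NATS-tss-v1_0-3ffb9-full.tar") would rebuild the tar from its downloaded parts.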

Note: if you encounter quota-exceeded errors when downloading from Google Drive, please try to (1) log in to your personal Google account, (2) right-click and copy the files to your personal Google Drive, and (3) download them from your personal Google Drive.

Usage

See more examples in the notebooks.

1, create the benchmark instance:

from nats_bench import create
# Create the API instance for the size search space in NATS
api = create(None, 'sss', fast_mode=True, verbose=True)

# Create the API instance for the topology search space in NATS
api = create(None, 'tss', fast_mode=True, verbose=True)

2, query the performance:

# Query the loss / accuracy / time for the 1234-th candidate architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by key
info = api.get_more_info(1234, 'cifar10')

# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(12, 'cifar10')

# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')
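The fourth value returned by simulate_train_eval is the accumulated simulated time, so a multi-trial search can stop once a time budget is spent. Below is a minimal sketch of budgeted random search. The names random_search and FakeAPI are hypothetical. FakeAPI only imitates the call shape shown above (index, dataset=..., hp=...) so that the snippet runs without the benchmark files. With the real API, you would pass api and num_candidates=15625 for tss.

```python
import random


def random_search(api, num_candidates, time_budget, dataset="cifar10", hp="12", seed=0):
    """Sample candidates uniformly until the simulated total cost exceeds time_budget (seconds)."""
    rng = random.Random(seed)
    best_index, best_acc, history = None, -1.0, []
    while True:
        index = rng.randrange(num_candidates)
        acc, _latency, _time_cost, total_time = api.simulate_train_eval(index, dataset=dataset, hp=hp)
        if total_time > time_budget:
            break  # this evaluation overruns the budget, so it is discarded
        history.append(index)
        if acc > best_acc:
            best_index, best_acc = index, acc
    return best_index, best_acc, history


class FakeAPI:
    """Hypothetical stand-in: fixed cost per query and a deterministic pseudo-accuracy."""

    def __init__(self, cost_per_query=10.0):
        self.cost_per_query = cost_per_query
        self.used_time = 0.0

    def simulate_train_eval(self, index, dataset, hp="12"):
        self.used_time += self.cost_per_query
        accuracy = float((index * 7919) % 100)  # pseudo-accuracy in [0, 100)
        return accuracy, 0.0, self.cost_per_query, self.used_time


best_index, best_acc, history = random_search(FakeAPI(), num_candidates=15625, time_budget=100.0)
print(len(history))  # 10: ten queries of 10 simulated seconds fit into a 100-second budget
```

The running total lives in the API instance, so repeated searches on one instance would need that total reset between runs.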

3, create the instance of an architecture candidate in NATS-Bench:

# Create the instance of the 12-th candidate for CIFAR-10.
# To keep the NATS-Bench repo concise, we do not include any model-related code here because it relies on PyTorch.
# The [models] package is defined at https://github.com/D-X-Y/AutoDL-Projects,
#   so one needs to first install and import this package.
import xautodl
from xautodl.models import get_cell_based_tiny_net
config = api.get_net_config(12, 'cifar10')
network = get_cell_based_tiny_net(config)

# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.
params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))

4, others:

# Clear the parameters of the 12-th candidate.
api.clear_params(12)

# Reload all information of the 12-th candidate.
api.reload(index=12)

Please see api_test.py for more examples.

from nats_bench import api_test
api_test.test_nats_bench_tss('NATS-tss-v1_0-3ffb9-simple')
api_test.test_nats_bench_sss('NATS-sss-v1_0-50262-simple')

How to Re-create NATS-Bench from Scratch

You need to use the AutoDL-Projects repo to re-create NATS-Bench from scratch.

The Size Search Space

The following command trains all architecture candidates in the size search space for 90 epochs with the random seed 777. To use a different number of training epochs, replace 90 with the desired value, e.g., 01 or 12.

bash ./scripts/NATS-Bench/train-shapes.sh 00000-32767 90 777

The checkpoints of all candidates are saved in output/NATS-Bench-size by default.

After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.

python exps/NATS-Bench/sss-collect.py
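Training every candidate in one process takes a long time. The training scripts might accept any sub-range in the same START-END format as 00000-32767. That is an assumption, so check the script before relying on it. If it holds, the index range could be split across workers with a hypothetical helper like this:

```python
def shard_ranges(total, num_shards):
    """Split indices 0..total-1 into contiguous, inclusive 'START-END' strings with 5-digit padding."""
    base, extra = divmod(total, num_shards)
    ranges, start = [], 0
    for shard in range(num_shards):
        size = base + (1 if shard < extra else 0)  # the first `extra` shards take one more index
        if size == 0:
            continue  # more shards than indices
        ranges.append(f"{start:05d}-{start + size - 1:05d}")
        start += size
    return ranges


# Hypothetical: one command per worker for the size search space (90 epochs, seed 777).
for index_range in shard_ranges(32768, 4):
    print(f"bash ./scripts/NATS-Bench/train-shapes.sh {index_range} 90 777")
```

The same helper with total=15625 would cover the topology range 00000-15624.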

The Topology Search Space

The following command trains all architecture candidates in the topology search space for 200 epochs with the random seeds 777, 888, and 999. To use a different number of training epochs, replace 200 with the desired value, e.g., 12.

bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999'

The checkpoints of all candidates are saved in output/NATS-Bench-topology by default.

After training these candidate architectures, please use the following command to re-organize all checkpoints into the official benchmark file.

python exps/NATS-Bench/tss-collect.py

To Reproduce 13 Baseline NAS Algorithms in NATS-Bench

You need to use the AutoDL-Projects repo to run the 13 baseline NAS methods. Here is a brief introduction to running each algorithm (NATS-algos).

Reproduce NAS methods on the topology search space

Please use the following commands to run different NAS methods on the topology search space:

Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py       --dataset cifar100 --search_space tss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py  --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
python ./exps/NATS-algos/bohb.py            --dataset cifar100 --search_space tss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
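The --ea_population and --ea_sample_size flags are the two knobs of regularized (aging) evolution. The following is a self-contained sketch, not the code in regularized_ea.py. It treats --ea_cycles as the total number of evaluated architectures, which is an assumption. It mutates one edge of the 6-edge cell at a time and uses a hypothetical toy fitness in place of benchmark queries. With the real benchmark, the fitness would come from a query such as api.simulate_train_eval.

```python
import collections
import random

OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
NUM_EDGES = 6  # edges i -> j for 0 <= i < j <= 3 in the 4-node cell


def to_arch_str(ops):
    """Render a 6-tuple of ops as '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|'."""
    nodes = [ops[0:1], ops[1:3], ops[3:6]]
    return "+".join("|" + "|".join(f"{op}~{i}" for i, op in enumerate(node)) + "|" for node in nodes)


def mutate(ops, rng):
    """Replace the op on one random edge with a different op."""
    edge = rng.randrange(NUM_EDGES)
    choices = [op for op in OPS if op != ops[edge]]
    return ops[:edge] + (rng.choice(choices),) + ops[edge + 1:]


def regularized_evolution(fitness, cycles=200, population_size=10, sample_size=3, seed=0):
    """Aging evolution: each cycle adds a mutated child and removes the oldest member."""
    rng = random.Random(seed)
    population, history = collections.deque(), []
    while len(population) < population_size:  # random initial population
        ops = tuple(rng.choice(OPS) for _ in range(NUM_EDGES))
        population.append((ops, fitness(ops)))
        history.append(population[-1])
    while len(history) < cycles:
        sample = rng.sample(list(population), sample_size)
        parent = max(sample, key=lambda member: member[1])  # best of the random sample
        child = mutate(parent[0], rng)
        population.append((child, fitness(child)))
        history.append(population[-1])
        population.popleft()  # remove the oldest, not the worst
    return max(history, key=lambda member: member[1])


# Hypothetical toy fitness: the number of 3x3 convolutions in the cell.
best_ops, best_score = regularized_evolution(lambda ops: ops.count("nor_conv_3x3"))
print(to_arch_str(best_ops), best_score)
```

Removing the oldest member rather than the worst is what distinguishes regularized evolution from plain tournament selection.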

DARTS (first order):
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v1
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1

DARTS (second order):
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v2
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2

GDAS:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas

SETN:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo setn
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn

Random Search with Weight Sharing:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo random
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random

ENAS:
python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001

Reproduce NAS methods on the size search space

Please use the following commands to run different NAS methods on the size search space:

Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py       --dataset cifar100 --search_space sss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py  --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space sss
python ./exps/NATS-algos/bohb.py            --dataset cifar100 --search_space sss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3


Run Transformable Architecture Search (TAS), proposed in Network Pruning via Transformable Architecture Search, NeurIPS 2019:

python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777


Run the channel search strategy of FBNet-V2 (masking + Gumbel-Softmax):

python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
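In the masking + Gumbel-Softmax strategy, a layer keeps all of its channels during search and multiplies its output by a soft mask. The mask is a Gumbel-Softmax-weighted sum of one prefix mask per candidate width, so the architecture logits receive gradients. Below is a plain-Python sketch for a single layer. It assumes the eight widths 8 to 64 of the size search space. The real implementation in search-size.py works on PyTorch tensors and may differ in details.

```python
import math
import random

CHANNELS = [8, 16, 24, 32, 40, 48, 56, 64]  # assumed candidate widths for one layer
EPS = 1e-10


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for logit in logits:
        u = min(max(rng.random(), EPS), 1.0 - EPS)  # keep both logs finite
        noisy.append((logit - math.log(-math.log(u))) / tau)
    peak = max(noisy)
    exps = [math.exp(x - peak) for x in noisy]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def soft_channel_mask(logits, tau=1.0, rng=None, max_channels=64):
    """Return (weights, mask) with mask = sum_k weights[k] * mask_k, mask_k keeping CHANNELS[k] channels."""
    rng = rng or random.Random()
    weights = gumbel_softmax(logits, tau, rng)
    mask = [0.0] * max_channels
    for w, width in zip(weights, CHANNELS):
        for c in range(width):
            mask[c] += w  # channel c is kept by every candidate wider than c
    return weights, mask


weights, mask = soft_channel_mask([0.0] * len(CHANNELS), tau=0.5, rng=random.Random(0))
print([round(m, 3) for m in mask[::8]])  # one value per 8-channel group, non-increasing
```

Lowering tau makes the weights closer to one-hot, so the soft mask approaches the hard prefix mask of a single width.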


Run the channel search strategy of TuNAS (masking + sampling):

python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
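In the masking + sampling strategy, one width is sampled per layer and the unused channels are masked out. The sampling distribution is then trained with REINFORCE. Below is a plain-Python sketch for one layer. The eight widths, the learning rate, and the moving-average baseline are illustrative choices, not settings read from search-size.py. The function toy_reward is hypothetical and stands in for the accuracy of the masked network.

```python
import math
import random

CHANNELS = [8, 16, 24, 32, 40, 48, 56, 64]  # assumed candidate widths for one layer


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hard_mask(width, max_channels=64):
    """Masking with a sampled width: keep the first `width` channels and zero the rest."""
    return [1.0 if c < width else 0.0 for c in range(max_channels)]


def reinforce_step(logits, reward_fn, baseline, rng, lr=0.1, momentum=0.9):
    """Sample one width, then ascend (reward - baseline) * d log p(sample) / d logits."""
    probs = softmax(logits)
    k = rng.choices(range(len(CHANNELS)), weights=probs)[0]
    reward = reward_fn(hard_mask(CHANNELS[k]))
    advantage = reward - baseline
    # For a softmax policy, d log p_k / d logit_j = 1[j == k] - p_j.
    new_logits = [logit + lr * advantage * ((1.0 if j == k else 0.0) - p)
                  for j, (logit, p) in enumerate(zip(logits, probs))]
    new_baseline = momentum * baseline + (1.0 - momentum) * reward  # moving-average baseline
    return new_logits, new_baseline, k


def toy_reward(mask):
    """Hypothetical stand-in for the masked network's accuracy: pretend 40 channels is best."""
    return -abs(sum(mask) - 40) / 64


rng = random.Random(0)
logits, baseline = [0.0] * len(CHANNELS), 0.0
for _ in range(2000):
    logits, baseline, _ = reinforce_step(logits, toy_reward, baseline, rng)
probs = softmax(logits)
print("most likely width:", CHANNELS[probs.index(max(probs))])
```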

Final Discovered Architectures for Each Algorithm

The architecture index can be found by calling api.query_index_by_arch(architecture_string).
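It can also help to inspect an architecture string locally. Below is a small sketch with hypothetical helpers that are not part of nats_bench. The format is read off the strings listed here: nodes are separated by '+', and each 'op~i' entry names the operation applied to the output of node i. The connectivity check flags cells whose output is not linked to the input through any non-'none' edge.

```python
from collections import Counter


def parse_arch(arch_str):
    """Parse '|op~i|+|op~i|op~j|+...' into one list of (op, input_node) pairs per node."""
    nodes = []
    for node_str in arch_str.split("+"):
        inputs = []
        for token in node_str.strip("|").split("|"):
            op, src = token.split("~")
            inputs.append((op, int(src)))
        nodes.append(inputs)
    return nodes


def op_histogram(arch_str):
    """Count how often each operation appears in the cell."""
    return Counter(op for node in parse_arch(arch_str) for op, _ in node)


def is_connected(arch_str):
    """True if the last node is reachable from node 0 (the cell input) through non-'none' edges."""
    reachable = [True]  # node 0 is the cell input
    for inputs in parse_arch(arch_str):
        reachable.append(any(reachable[src] and op != "none" for op, src in inputs))
    return reachable[-1]


darts = "|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|"
print(op_histogram(darts), is_connected(darts))
```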

The final discovered architectures on CIFAR-10:

DARTS (first order):
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|

DARTS (second order):
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|
|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|

GDAS:
|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|nor_conv_3x3~2|
|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
|avg_pool_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|nor_conv_3x3~0|nor_conv_1x1~1|nor_conv_1x1~2|

The final discovered architectures on CIFAR-100:

DARTS (V1):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|none~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|none~2|
|skip_connect~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|nor_conv_3x3~2|

DARTS (V2):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_1x1~1|skip_connect~2|
|skip_connect~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|none~2|
|skip_connect~0|+|nor_conv_1x1~0|none~1|+|nor_conv_3x3~0|skip_connect~1|none~2|

GDAS:
|nor_conv_3x3~0|+|nor_conv_1x1~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
|avg_pool_3x3~0|+|nor_conv_1x1~0|none~1|+|nor_conv_3x3~0|avg_pool_3x3~1|nor_conv_1x1~2|
|avg_pool_3x3~0|+|nor_conv_3x3~0|none~1|+|nor_conv_3x3~0|nor_conv_1x1~1|nor_conv_1x1~2|

The final discovered architectures on ImageNet16-120:

DARTS (V1):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|nor_conv_1x1~2|

DARTS (V2):
|none~0|+|skip_connect~0|none~1|+|skip_connect~0|none~1|skip_connect~2|

GDAS:
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|none~0|none~1|+|nor_conv_3x3~0|none~1|none~2|

Others

We use black as the Python code formatter. Please run black . -l 120.

Citation

If you find that NATS-Bench helps your research, please consider citing it:

@article{dong2021nats,
  title   = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
  author  = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
  doi     = {10.1109/TPAMI.2021.3054824},
  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
  year    = {2021},
  note    = {\mbox{doi}:\url{10.1109/TPAMI.2021.3054824}}
}
@inproceedings{dong2020nasbench201,
  title     = {{NAS-Bench-201}: Extending the Scope of Reproducible Neural Architecture Search},
  author    = {Dong, Xuanyi and Yang, Yi},
  booktitle = {International Conference on Learning Representations (ICLR)},
  url       = {https://openreview.net/forum?id=HJxyZkBKDr},
  year      = {2020}
}