Differentiable architecture search for convolutional and recurrent networks

Overview

Differentiable Architecture Search

Code accompanying the paper

DARTS: Differentiable Architecture Search
Hanxiao Liu, Karen Simonyan, Yiming Yang.
arXiv:1806.09055.

darts

The algorithm is based on continuous relaxation and gradient descent in the architecture space. It is able to efficiently design high-performance convolutional architectures for image classification (on CIFAR-10 and ImageNet) and recurrent architectures for language modeling (on Penn Treebank and WikiText-2). Only a single GPU is required.

Requirements

Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0

NOTE: PyTorch 0.4 is not supported at this moment and would lead to OOM.

Datasets

Instructions for acquiring PTB and WT2 can be found here. While CIFAR-10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded (preferably to a SSD) following the instructions here.

Pretrained models

The easist way to get started is to evaluate our pretrained DARTS models.

CIFAR-10 (cifar10_model.pt)

cd cnn && python test.py --auxiliary --model_path cifar10_model.pt
  • Expected result: 2.63% test error rate with 3.3M model params.

PTB (ptb_model.pt)

cd rnn && python test.py --model_path ptb_model.pt
  • Expected result: 55.68 test perplexity with 23M model params.

ImageNet (imagenet_model.pt)

cd cnn && python test_imagenet.py --auxiliary --model_path imagenet_model.pt
  • Expected result: 26.7% top-1 error and 8.7% top-5 error with 4.7M model params.

Architecture search (using small proxy models)

To carry out architecture search using 2nd-order approximation, run

cd cnn && python train_search.py --unrolled     # for conv cells on CIFAR-10
cd rnn && python train_search.py --unrolled     # for recurrent cells on PTB

Note the validation performance in this step does not indicate the final performance of the architecture. One must train the obtained genotype/architecture from scratch using full-sized models, as described in the next section.

Also be aware that different runs would end up with different local minimum. To get the best result, it is crucial to repeat the search process with different seeds and select the best cell(s) based on validation performance (obtained by training the derived cell from scratch for a small number of epochs). Please refer to fig. 3 and sect. 3.2 in our arXiv paper.

progress_convolutional_normal progress_convolutional_reduce progress_recurrent

Figure: Snapshots of the most likely normal conv, reduction conv, and recurrent cells over time.

Architecture evaluation (using full-sized models)

To evaluate our best cells by training from scratch, run

cd cnn && python train.py --auxiliary --cutout            # CIFAR-10
cd rnn && python train.py                                 # PTB
cd rnn && python train.py --data ../data/wikitext-2 \     # WT2
            --dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7
cd cnn && python train_imagenet.py --auxiliary            # ImageNet

Customized architectures are supported through the --arch flag once specified in genotypes.py.

The CIFAR-10 result at the end of training is subject to variance due to the non-determinism of cuDNN back-prop kernels. It would be misleading to report the result of only a single run. By training our best cell from scratch, one should expect the average test error of 10 independent runs to fall in the range of 2.76 +/- 0.09% with high probability.

cifar10 ptb ptb

Figure: Expected learning curves on CIFAR-10 (4 runs), ImageNet and PTB.

Visualization

Package graphviz is required to visualize the learned cells

python visualize.py DARTS

where DARTS can be replaced by any customized architectures in genotypes.py.

Citation

If you use any part of this code in your research, please cite our paper:

@article{liu2018darts,
  title={DARTS: Differentiable Architecture Search},
  author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  journal={arXiv preprint arXiv:1806.09055},
  year={2018}
}
Owner
Hanxiao Liu
Research Scientist @ Google Brain
Hanxiao Liu
The Official TensorFlow Implementation for SPatchGAN (ICCV2021)

SPatchGAN: Official TensorFlow Implementation Paper "SPatchGAN: A Statistical Feature Based Discriminator for Unsupervised Image-to-Image Translation"

39 Dec 30, 2022
A fast model to compute optical flow between two input images.

DCVNet: Dilated Cost Volumes for Fast Optical Flow This repository contains our implementation of the paper: @InProceedings{jiang2021dcvnet, title={

Huaizu Jiang 8 Sep 27, 2021
This repository contains the code for "Self-Diagnosis and Self-Debiasing: A Proposal for Reducing Corpus-Based Bias in NLP".

Self-Diagnosis and Self-Debiasing This repository contains the source code for Self-Diagnosis and Self-Debiasing: A Proposal for Reducing Corpus-Based

Timo Schick 62 Dec 12, 2022
Official Implementation for Fast Training of Neural Lumigraph Representations using Meta Learning.

Fast Training of Neural Lumigraph Representations using Meta Learning Project Page | Paper | Data Alexander W. Bergman, Petr Kellnhofer, Gordon Wetzst

Alex 39 Oct 08, 2022
[CVPR 2021] Modular Interactive Video Object Segmentation: Interaction-to-Mask, Propagation and Difference-Aware Fusion

[CVPR 2021] Modular Interactive Video Object Segmentation: Interaction-to-Mask, Propagation and Difference-Aware Fusion

Rex Cheng 364 Jan 03, 2023
Unofficial implementation of Google "CutPaste: Self-Supervised Learning for Anomaly Detection and Localization" in PyTorch

CutPaste CutPaste: image from paper Unofficial implementation of Google's "CutPaste: Self-Supervised Learning for Anomaly Detection and Localization"

Lilit Yolyan 59 Nov 27, 2022
Semi-Supervised Semantic Segmentation via Adaptive Equalization Learning, NeurIPS 2021 (Spotlight)

Semi-Supervised Semantic Segmentation via Adaptive Equalization Learning, NeurIPS 2021 (Spotlight) Abstract Due to the limited and even imbalanced dat

Hanzhe Hu 99 Dec 12, 2022
Python library to receive live stream events like comments and gifts in realtime from TikTok LIVE.

TikTokLive A python library to connect to and read events from TikTok's LIVE service A python library to receive and decode livestream events such as

Isaac Kogan 277 Dec 23, 2022
Code for Blind Image Decomposition (BID) and Blind Image Decomposition network (BIDeN).

arXiv, porject page, paper Blind Image Decomposition (BID) Blind Image Decomposition is a novel task. The task requires separating a superimposed imag

64 Dec 20, 2022
Code for the IJCAI 2021 paper "Structure Guided Lane Detection"

SGNet Project for the IJCAI 2021 paper "Structure Guided Lane Detection" Abstract Recently, lane detection has made great progress with the rapid deve

Jinming Su 27 Dec 08, 2022
This is the repository for the paper "Have I done enough planning or should I plan more?"

Metacognitive Learning Tool box https://re.is.mpg.de What Is This? This repository contains two modules used to analyse metacognitive learning in huma

0 Dec 01, 2021
Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators

Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators This is our Pytorch implementation for t

RUCAIBox 12 Jul 22, 2022
Many Class Activation Map methods implemented in Pytorch for CNNs and Vision Transformers. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM

Class Activation Map methods implemented in Pytorch pip install grad-cam ⭐ Tested on many Common CNN Networks and Vision Transformers. ⭐ Includes smoo

Jacob Gildenblat 6.6k Jan 06, 2023
Ego4d dataset repository. Download the dataset, visualize, extract features & example usage of the dataset

Ego4D EGO4D is the world's largest egocentric (first person) video ML dataset and benchmark suite, with 3,600 hrs (and counting) of densely narrated v

Meta Research 118 Jan 07, 2023
K-FACE Analysis Project on Pytorch

Installation Setup with Conda # create a new environment conda create --name insightKface python=3.7 # or over conda activate insightKface #install t

Jung Jun Uk 7 Nov 10, 2022
The FIRST GANs-based omics-to-omics translation framework

OmiTrans Please also have a look at our multi-omics multi-task DL freamwork 👀 : OmiEmbed The FIRST GANs-based omics-to-omics translation framework Xi

Xiaoyu Zhang 6 Dec 14, 2022
Streamlit Tutorial (ex: stock price dashboard, cartoon-stylegan, vqgan-clip, stylemixing, styleclip, sefa)

Streamlit Tutorials Install pip install streamlit Run cd [directory] streamlit run app.py --server.address 0.0.0.0 --server.port [your port] # http:/

Jihye Back 30 Jan 06, 2023
DeepSpeed is a deep learning optimization library that makes distributed training easy, efficient, and effective.

DeepSpeed+Megatron trained the world's most powerful language model: MT-530B DeepSpeed is hiring, come join us! DeepSpeed is a deep learning optimizat

Microsoft 8.4k Dec 28, 2022
Uncertainty-aware Semantic Segmentation of LiDAR Point Clouds for Autonomous Driving

SalsaNext: Fast, Uncertainty-aware Semantic Segmentation of LiDAR Point Clouds for Autonomous Driving Abstract In this paper, we introduce SalsaNext f

308 Jan 04, 2023
ZeroVL - The official implementation of ZeroVL

This repository contains source code necessary to reproduce the results presente

31 Nov 04, 2022