CATE: Computation-aware Neural Architecture Encoding with Transformers

Code for paper:

CATE: Computation-aware Neural Architecture Encoding with Transformers
Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang.
ICML 2021 (Long Talk).

Overview of CATE: It takes pairs of computationally similar architectures as input and is trained to predict masked operators given the pairwise computation information. After pretraining, the Transformer encoder, without the cross-attention blocks, is used to extract architecture encodings for the downstream search.
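
The masking step of the pretraining objective can be illustrated with a minimal sketch. It only shows how operators in a pair of architecture sequences might be hidden and kept as prediction targets; it does not reproduce the Transformer or the training loop in this repository. The `[MASK]` token name, the `mask_ratio` value, the helper `mask_operators`, and the operator names are all assumptions made for illustration.

```python
import random

MASK = "[MASK]"  # hypothetical mask token name, not taken from the repository


def mask_operators(ops, mask_ratio=0.2, seed=0):
    """Hide a fraction of operators and return (masked sequence, targets).

    `mask_ratio` is an illustrative value, not the one used by CATE.
    """
    rng = random.Random(seed)
    n_mask = max(1, round(mask_ratio * len(ops)))
    positions = sorted(rng.sample(range(len(ops)), n_mask))
    masked = list(ops)
    labels = {}
    for p in positions:
        labels[p] = ops[p]   # the model would be trained to recover this operator
        masked[p] = MASK
    return masked, labels


# Two illustrative operator sequences standing in for a computationally similar pair.
arch_a = ["input", "conv3x3", "conv1x1", "maxpool3x3", "output"]
arch_b = ["input", "conv1x1", "conv3x3", "conv3x3", "output"]

masked_a, labels_a = mask_operators(arch_a, seed=1)
masked_b, labels_b = mask_operators(arch_b, seed=2)
print(masked_a, labels_a)
print(masked_b, labels_b)
```

In the actual model, both masked sequences are fed to the encoder together so that each prediction can draw on the paired architecture.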

The repository is built upon pybnn and nas-encodings.

Requirements

conda create -n tf python=3.7
source activate tf
cat requirements.txt | xargs -L 1 pip install

Experiments on NAS-Bench-101

Dataset preparation on NAS-Bench-101

Install nasbench and download nasbench_only108.tfrecord in ./data folder.

python preprocessing/gen_json.py

Data will be saved in ./data/nasbench101.json.

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench101 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench101 --flag build_pair --k 2 --d 2000000 --metric params

The corresponding training data and pairs will be saved in ./data/nasbench101/.
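
To give an intuition for the `build_pair` step, here is a minimal sketch of one way such pairs could be formed. It assumes that `--k` is the number of partners drawn per architecture and `--d` is the maximum allowed difference in the chosen metric (parameter count here); the function `build_pairs` and the sample values are hypothetical and are not the code in `preprocessing/data_generate.py`.

```python
import bisect
import random


def build_pairs(metric_values, k=2, d=2_000_000, seed=0):
    """Return (i, j) index pairs whose metric values differ by at most d.

    For each architecture i, up to k distinct partners are sampled from
    the architectures whose metric lies within d of its own.
    """
    rng = random.Random(seed)
    # Sort indices by metric so each neighbourhood is a contiguous window.
    order = sorted(range(len(metric_values)), key=lambda i: metric_values[i])
    sorted_vals = [metric_values[i] for i in order]

    pairs = []
    for pos, i in enumerate(order):
        lo = bisect.bisect_left(sorted_vals, sorted_vals[pos] - d)
        hi = bisect.bisect_right(sorted_vals, sorted_vals[pos] + d)
        neighbours = [order[p] for p in range(lo, hi) if p != pos]
        for j in rng.sample(neighbours, min(k, len(neighbours))):
            pairs.append((i, j))
    return pairs


# Hypothetical parameter counts for five architectures.
params = [1_000_000, 2_500_000, 2_600_000, 9_000_000, 30_000_000]
pairs = build_pairs(params, k=2, d=2_000_000)
print(pairs)
```

Architectures with no neighbour inside the window (the last two in this example) simply contribute no pairs.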

Alternatively, you can download the data train_data.pt, test_data.pt and pair indices train_pair_k2_d2000000_metric_params.pt, test_pair_k2_d2000000_metric_params.pt from here.

Pretraining

bash run_scripts/pretrain_nasbench101.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench101_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench101_model_best.pth.tar --train_data data/nasbench101/train_data.pt --valid_data data/nasbench101/test_data.pt --dataset nasbench101

The extracted embeddings will be saved in ./cate_nasbench101.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench101.pt from here.

Run search experiments on NAS-Bench-101

bash run_scripts/run_search_nasbench101.sh

Search results will be saved in ./nasbench101/.

Experiments on NAS-Bench-301

Dataset preparation

Install nasbench301 and download the xgb_v1.0 and lgb_runtime_v1.0 files. You may need to install a pytorch_geometric build that is compatible with your PyTorch and CUDA versions.

python preprocessing/gen_json_darts.py # randomly sample 1,000,000 archs

Data will be saved in ./data/nasbench301_proxy.json.

Alternatively, you can download the json file nasbench301_proxy.json from here.

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench301 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench301 --flag build_pair --k 1 --d 5000000 --metric flops

The corresponding training data and pairs will be saved in ./data/nasbench301/.

Alternatively, you can download the data train_data.pt, test_data.pt and pair indices train_pair_k1_d5000000_metric_flops.pt, test_pair_k1_d5000000_metric_flops.pt from here.

Pretraining

bash run_scripts/pretrain_nasbench301.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench301_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench301_model_best.pth.tar --train_data data/nasbench301/train_data.pt --valid_data data/nasbench301/test_data.pt --dataset nasbench301 --n_vocab 11

The extracted encodings will be saved in ./cate_nasbench301.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench301.pt from here.

Run search experiments on NAS-Bench-301

bash run_scripts/run_search_nasbench301.sh

Search results will be saved in ./nasbench301/.

DARTS experiments without surrogate models

Download the pretrained embeddings cate_darts.pt from here.

python search_methods/dngo_ls_darts.py --dim 64 --init_size 16 --topk 5 --dataset darts --output_path bo  --embedding_path cate_darts.pt

The search log will be saved in ./darts/. The final search result will be saved in ./darts/bo/dim64.

Evaluate the learned cell on DARTS Search Space on CIFAR-10

python darts/cnn/train.py --auxiliary --cutout --arch cate_small
python darts/cnn/train.py --auxiliary --cutout --arch cate_large
  • Expected results (CATE-Small): 2.55% avg. test error with 3.5M model params.
  • Expected results (CATE-Large): 2.46% avg. test error with 4.1M model params.

Transfer learning on ImageNet

python darts/cnn/train_imagenet.py  --arch cate_small --seed 1 
python darts/cnn/train_imagenet.py  --arch cate_large --seed 1
  • Expected results (CATE-Small): 26.05% test error with 5.0M model params and 556M mult-adds.
  • Expected results (CATE-Large): 25.01% test error with 5.8M model params and 642M mult-adds.

Visualize the learned cell

python darts/cnn/visualize.py cate_small
python darts/cnn/visualize.py cate_large

Experiments on outside search space

Build outside search space dataset

bash run_scripts/generate_oo.sh

Data will be saved in ./data/nasbench101_oo_train.json and ./data/nasbench101_oo_test.json.

Generate architecture pairs

python preprocessing/data_generate_oo.py --flag extract_seq
python preprocessing/data_generate_oo.py --flag build_pair

The corresponding training data and pair indices will be saved in ./data/nasbench101/.

Pretraining

python run.py --do_train --parallel --train_data data/nasbench101/nasbench101_oo_trainSet_train.pt --train_pair data/nasbench101/oo_train_pairs_k2_params_dist2e6.pt  --valid_data data/nasbench101/nasbench101_oo_trainSet_validation.pt --valid_pair data/nasbench101/oo_validation_pairs_k2_params_dist2e6.pt --dataset oo

The pretrained models will be saved in ./model/.

Extract embeddings on outside search space

# Adjacency encoding
python inference/inference_adj.py
# CATE encoding
python inference/inference.py --pretrained_path model/oo_model_best.pth.tar --train_data data/nasbench101/nasbench101_oo_testSet_split1.pt --valid_data data/nasbench101/nasbench101_oo_testSet_split2.pt --dataset oo_nasbench101

The extracted encodings will be saved as ./adj_oo_nasbench101.pt and ./cate_oo_nasbench101.pt.

Alternatively, you can download the data, pair indices, pretrained models, and extracted embeddings from here.

Run MLP predictor experiments on outside search space

for s in {1..500}; do python search_methods/oo_mlp.py --dim 27 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_adj  --embedding_path adj_oo_nasbench101.pt; done
for s in {1..500}; do python search_methods/oo_mlp.py --dim 64 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_cate  --embedding_path cate_oo_nasbench101.pt; done

Search results will be saved in ./oo_nasbench101.
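
If the shell loops above are inconvenient (for example, on a platform without bash brace expansion), the same 500-seed sweep can be generated from Python. The following is a minimal sketch: the helper `oo_mlp_commands` is hypothetical, while the script path and flags are copied from the commands above. It only builds the command lines; running them is left to the caller.

```python
def oo_mlp_commands(dim, output_path, embedding_path, seeds=range(1, 501)):
    """Build one oo_mlp.py command (as an argument list) per seed."""
    cmds = []
    for seed in seeds:
        cmds.append([
            "python", "search_methods/oo_mlp.py",
            "--dim", str(dim),
            "--seed", str(seed),
            "--init_size", "16",
            "--topk", "5",
            "--dataset", "oo_nasbench101",
            "--output_path", output_path,
            "--embedding_path", embedding_path,
        ])
    return cmds


# Adjacency encoding (dim 27) and CATE encoding (dim 64), as in the loops above.
adj_cmds = oo_mlp_commands(27, "np_adj", "adj_oo_nasbench101.pt")
cate_cmds = oo_mlp_commands(64, "np_cate", "cate_oo_nasbench101.pt")

# Show the first command of each sweep.
print(" ".join(adj_cmds[0]))
print(" ".join(cate_cmds[0]))
```

Each argument list can be passed to `subprocess.run(cmd, check=True)` from the repository root to execute the run.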

Citation

If you find this useful for your work, please consider citing:

@InProceedings{yan2021cate,
  title = {CATE: Computation-aware Neural Architecture Encoding with Transformers},
  author = {Yan, Shen and Song, Kaiqiang and Liu, Fei and Zhang, Mi},
  booktitle = {ICML},
  year = {2021}
}