This is an official PyTorch implementation of Task-Adaptive Neural Network Search with Meta-Contrastive Learning (NeurIPS 2021, Spotlight).

Related tags

Deep LearningTANS
Overview

NeurIPS 2021 (Spotlight): Task-Adaptive Neural Network Search with Meta-Contrastive Learning

This is an official PyTorch implementation of Task-Adaptive Neural Network Search with Meta-Contrastive Learning. Accepted to NeurIPS 2021 (Spotlight).

@inproceedings{jeong2021task,
    title     = {Task-Adaptive Neural Network Search with Meta-Contrastive Learning},
    author    = {Jeong, Wonyong and Lee, Hayeon and Park, Geon and Hyung, Eunyoung and Baek, Jinheon and Hwang, Sung Ju},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year      = {2021}
} 

Overview

Most conventional Neural Architecture Search (NAS) approaches are limited in that they only generate architectures without searching for the optimal parameters. While some NAS methods handle this issue by utilizing a supernet trained on a large-scale dataset such as ImageNet, they may be suboptimal if the target tasks are highly dissimilar from the dataset the supernet is trained on. To address such limitations, we introduce a novel problem of Neural Network Search (NNS), whose goal is to search for the optimal pretrained network for a novel dataset and constraints (e.g. number of parameters), from a model zoo. Then, we propose a novel framework to tackle the problem, namely Task-Adaptive Neural Network Search (TANS). Given a model-zoo that consists of network pretrained on diverse datasets, we use a novel amortized meta-learning framework to learn a cross-modal latent space with contrastive loss, to maximize the similarity between a dataset and a high-performing network on it, and minimize the similarity between irrelevant dataset-network pairs. We validate the effectiveness and efficiency of our method on ten real-world datasets, against existing NAS/AutoML baselines. The results show that our method instantly retrieves networks that outperform models obtained with the baselines with significantly fewer training steps to reach the target performance, thus minimizing the total cost of obtaining a task-optimal network.

Prerequisites

  • Python 3.8 (Anaconda)
  • PyTorch 1.8.1
  • CUDA 10.2

Environmental Setup

Please install packages thorugh requirements.txt after creating your own environment with python 3.8.x.

$ conda create --name ENV_NAME python=3.8
$ conda activate ENV_NAME
$ conda install pytorch==1.8.1 torchvision cudatoolkit=10.2 -c pytorch
$ pip install --upgrade pip
$ pip install -r requirements.txt

Preparation

We provide our model-zoo consisting of 14K pretrained models on various Kaggle datasets. We also share the full raw datasets collected from Kaggle as well as their processed versions of datasets for meta-training and meta-test in our learning framework. Except for the raw datasets, all the processed files are required to perform the cross model retrieval learning and meta-testing on unseen datasets. Please download following files before training or testing. (Due to the heavy file size, some files will be updated by Oct. 28th. Sorry for the inconvenience).

No. File Name Description Extension Size Download
1 p_mod_zoo Processed 14K Model-Zoo pt 91.9Mb Link
2 ofa_nets Pretrained OFA Supernets zip - Pending
3 raw_m_train Raw Meta-Training Datasets zip - Pending
4 raw_m_test Raw Meta-Test Datasets zip - Pending
5 p_m_train Processed Meta-Training Files pt 69Mb Link
6 p_m_test Processed Meta-Test Files zip 11.6Gb Link

After download, specify their location on following arguments:

  • data-path: 5 and 6 should be placed. 6 must be unzipped.
  • model-zoo: path where 1 should be located. Please give full path to the file. i.e. path/to/p_mod_zoo.pt
  • model-zoo-raw: path where 2 should be placed and unzipped (required for meta-test experiments)

Learning the Cross Modal Retrieval Space

Please use following command to learn the cross modal space. Keep in mind that correct model-zoo and data-path are required. Forbase-path, this path is for storing training outcomes, such as resutls, logs, the cross modal embeddings, etc.

$ python3 main.py --gpu $1 \
                  --mode train \
                  --batch-size 140 \
                  --n-epochs 10000 \
                  --base-path path/for/storing/outcomes/\
                  --data-path path/to/processed/dataset/is/stored/\
                  --model-zoo path/to/model_zoo.pt\
                  --seed 777 

You can also simply run a script file after updating the paths.

$ cd scripts
$ sh train.sh GPU_NO

Meta-Test Experiment

You can use following command for testing the cross-modal retrieval performance on unseen meta-test datasets. In this experiment, load-path which is the base-path of the cross modal space that you previously built and model-zoo-raw which is path for the OFA supernets pretrained on meta-training datasets are required.

$ python3 ../main.py --gpu $1 \
                     --mode test \
                     --n-retrievals 10\
                     --n-eps-finetuning 50\
                     --batch-size 32\
                     --load-path path/to/outcomes/stored/\
                     --data-path path/to/processed/dataset/is/stored/\
                     --model-zoo path/to/model_zoo.pt\
                     --model-zoo-raw path/to/pretrained/ofa/models/\
                     --base-path path/for/storing/outcomes/\
                     --seed 777

You can also simply run a script file after updating the paths.

$ cd scripts
$ sh test.sh GPU_NO
Owner
Wonyong Jeong
Ph.D. Candidate @ KAIST AI
Wonyong Jeong
Jetson Nano-based smart camera system that measures crowd face mask usage in real-time.

MaskCam MaskCam is a prototype reference design for a Jetson Nano-based smart camera system that measures crowd face mask usage in real-time, with all

BDTI 212 Dec 29, 2022
A graphical Semi-automatic annotation tool based on labelImg and Yolov5

💕YOLOV5 semi-automatic annotation tool (Based on labelImg)

EricFang 247 Jan 05, 2023
Effect of Deep Transfer and Multi task Learning on Sperm Abnormality Detection

Effect of Deep Transfer and Multi task Learning on Sperm Abnormality Detection Introduction This repository includes codes and models of "Effect of De

Amir Abbasi 5 Sep 05, 2022
Deep Markov Factor Analysis (NeurIPS2021)

Deep Markov Factor Analysis (DMFA) Codes and experiments for deep Markov factor analysis (DMFA) model accepted for publication at NeurIPS2021: A. Farn

Sarah Ostadabbas 2 Dec 16, 2022
SAAVN - Sound Adversarial Audio-Visual Navigation,ICLR2022 (In PyTorch)

SAAVN SAAVN Code release for paper "Sound Adversarial Audio-Visual Navigation,IC

YinfengYu 10 Aug 30, 2022
Deeply Supervised, Layer-wise Prediction-aware (DSLP) Transformer for Non-autoregressive Neural Machine Translation

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision Training Efficiency We show the training efficiency of our DSLP model b

Chenyang Huang 36 Oct 31, 2022
Tree Nested PyTorch Tensor Lib

DI-treetensor treetensor is a generalized tree-based tensor structure mainly developed by OpenDILab Contributors. Almost all the operation can be supp

OpenDILab 167 Dec 29, 2022
Auto-Lama combines object detection and image inpainting to automate object removals

Auto-Lama Auto-Lama combines object detection and image inpainting to automate object removals. It is build on top of DE:TR from Facebook Research and

44 Dec 09, 2022
Data and analysis code for an MS on SK VOC genomes phenotyping/neutralisation assays

Description Summary of phylogenomic methods and analyses used in "Immunogenicity of convalescent and vaccinated sera against clinical isolates of ance

Finlay Maguire 1 Jan 06, 2022
Bulk2Space is a spatial deconvolution method based on deep learning frameworks

Bulk2Space Spatially resolved single-cell deconvolution of bulk transcriptomes using Bulk2Space Bulk2Space is a spatial deconvolution method based on

Dr. FAN, Xiaohui 60 Dec 27, 2022
Machine learning notebooks in different subjects optimized to run in google collaboratory

Notebooks Name Description Category Link Training pix2pix This notebook shows a simple pipeline for training pix2pix on a simple dataset. Most of the

Zaid Alyafeai 363 Dec 06, 2022
A Python 3 package for state-of-the-art statistical dimension reduction methods

direpack: a Python 3 library for state-of-the-art statistical dimension reduction techniques This package delivers a scikit-learn compatible Python 3

Sven Serneels 32 Dec 14, 2022
Code and data for paper "Deep Photo Style Transfer"

deep-photo-styletransfer Code and data for paper "Deep Photo Style Transfer" Disclaimer This software is published for academic and non-commercial use

Fujun Luan 9.9k Dec 29, 2022
Using modified BiSeNet for face parsing in PyTorch

face-parsing.PyTorch Contents Training Demo References Training Prepare training data: -- download CelebAMask-HQ dataset -- change file path in the pr

zll 1.6k Jan 08, 2023
Codes for the ICCV'21 paper "FREE: Feature Refinement for Generalized Zero-Shot Learning"

FREE This repository contains the reference code for the paper "FREE: Feature Refinement for Generalized Zero-Shot Learning". [arXiv][Paper] 1. Prepar

Shiming Chen 28 Jul 29, 2022
Deep deconfounded recommender (Deep-Deconf) for paper "Deep causal reasoning for recommendations"

Deep Causal Reasoning for Recommender Systems The codes are associated with the following paper: Deep Causal Reasoning for Recommendations, Yaochen Zh

Yaochen Zhu 22 Oct 15, 2022
MBPO (paper: When to trust your model: Model-based policy optimization) in offline RL settings

offline-MBPO This repository contains the code of a version of model-based RL algorithm MBPO, which is modified to perform in offline RL settings Pape

LxzGordon 1 Oct 24, 2021
EMNLP 2021 Adapting Language Models for Zero-shot Learning by Meta-tuning on Dataset and Prompt Collections

Adapting Language Models for Zero-shot Learning by Meta-tuning on Dataset and Prompt Collections Ruiqi Zhong, Kristy Lee*, Zheng Zhang*, Dan Klein EMN

Ruiqi Zhong 42 Nov 03, 2022
Generic Event Boundary Detection: A Benchmark for Event Segmentation

Generic Event Boundary Detection: A Benchmark for Event Segmentation We release our data annotation & baseline codes for detecting generic event bound

47 Nov 22, 2022
Semi-Supervised Graph Prototypical Networks for Hyperspectral Image Classification, IGARSS, 2021.

Semi-Supervised Graph Prototypical Networks for Hyperspectral Image Classification, IGARSS, 2021. Bobo Xi, Jiaojiao Li, Yunsong Li and Qian Du. Code f

Bobo Xi 7 Nov 03, 2022