This is an official PyTorch implementation of Task-Adaptive Neural Network Search with Meta-Contrastive Learning (NeurIPS 2021, Spotlight).

Related tags

Deep LearningTANS
Overview

NeurIPS 2021 (Spotlight): Task-Adaptive Neural Network Search with Meta-Contrastive Learning

This is an official PyTorch implementation of Task-Adaptive Neural Network Search with Meta-Contrastive Learning. Accepted to NeurIPS 2021 (Spotlight).

@inproceedings{jeong2021task,
    title     = {Task-Adaptive Neural Network Search with Meta-Contrastive Learning},
    author    = {Jeong, Wonyong and Lee, Hayeon and Park, Geon and Hyung, Eunyoung and Baek, Jinheon and Hwang, Sung Ju},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year      = {2021}
} 

Overview

Most conventional Neural Architecture Search (NAS) approaches are limited in that they only generate architectures without searching for the optimal parameters. While some NAS methods handle this issue by utilizing a supernet trained on a large-scale dataset such as ImageNet, they may be suboptimal if the target tasks are highly dissimilar from the dataset the supernet is trained on. To address such limitations, we introduce a novel problem of Neural Network Search (NNS), whose goal is to search for the optimal pretrained network for a novel dataset and constraints (e.g. number of parameters), from a model zoo. Then, we propose a novel framework to tackle the problem, namely Task-Adaptive Neural Network Search (TANS). Given a model-zoo that consists of network pretrained on diverse datasets, we use a novel amortized meta-learning framework to learn a cross-modal latent space with contrastive loss, to maximize the similarity between a dataset and a high-performing network on it, and minimize the similarity between irrelevant dataset-network pairs. We validate the effectiveness and efficiency of our method on ten real-world datasets, against existing NAS/AutoML baselines. The results show that our method instantly retrieves networks that outperform models obtained with the baselines with significantly fewer training steps to reach the target performance, thus minimizing the total cost of obtaining a task-optimal network.

Prerequisites

  • Python 3.8 (Anaconda)
  • PyTorch 1.8.1
  • CUDA 10.2

Environmental Setup

Please install packages thorugh requirements.txt after creating your own environment with python 3.8.x.

$ conda create --name ENV_NAME python=3.8
$ conda activate ENV_NAME
$ conda install pytorch==1.8.1 torchvision cudatoolkit=10.2 -c pytorch
$ pip install --upgrade pip
$ pip install -r requirements.txt

Preparation

We provide our model-zoo consisting of 14K pretrained models on various Kaggle datasets. We also share the full raw datasets collected from Kaggle as well as their processed versions of datasets for meta-training and meta-test in our learning framework. Except for the raw datasets, all the processed files are required to perform the cross model retrieval learning and meta-testing on unseen datasets. Please download following files before training or testing. (Due to the heavy file size, some files will be updated by Oct. 28th. Sorry for the inconvenience).

No. File Name Description Extension Size Download
1 p_mod_zoo Processed 14K Model-Zoo pt 91.9Mb Link
2 ofa_nets Pretrained OFA Supernets zip - Pending
3 raw_m_train Raw Meta-Training Datasets zip - Pending
4 raw_m_test Raw Meta-Test Datasets zip - Pending
5 p_m_train Processed Meta-Training Files pt 69Mb Link
6 p_m_test Processed Meta-Test Files zip 11.6Gb Link

After download, specify their location on following arguments:

  • data-path: 5 and 6 should be placed. 6 must be unzipped.
  • model-zoo: path where 1 should be located. Please give full path to the file. i.e. path/to/p_mod_zoo.pt
  • model-zoo-raw: path where 2 should be placed and unzipped (required for meta-test experiments)

Learning the Cross Modal Retrieval Space

Please use following command to learn the cross modal space. Keep in mind that correct model-zoo and data-path are required. Forbase-path, this path is for storing training outcomes, such as resutls, logs, the cross modal embeddings, etc.

$ python3 main.py --gpu $1 \
                  --mode train \
                  --batch-size 140 \
                  --n-epochs 10000 \
                  --base-path path/for/storing/outcomes/\
                  --data-path path/to/processed/dataset/is/stored/\
                  --model-zoo path/to/model_zoo.pt\
                  --seed 777 

You can also simply run a script file after updating the paths.

$ cd scripts
$ sh train.sh GPU_NO

Meta-Test Experiment

You can use following command for testing the cross-modal retrieval performance on unseen meta-test datasets. In this experiment, load-path which is the base-path of the cross modal space that you previously built and model-zoo-raw which is path for the OFA supernets pretrained on meta-training datasets are required.

$ python3 ../main.py --gpu $1 \
                     --mode test \
                     --n-retrievals 10\
                     --n-eps-finetuning 50\
                     --batch-size 32\
                     --load-path path/to/outcomes/stored/\
                     --data-path path/to/processed/dataset/is/stored/\
                     --model-zoo path/to/model_zoo.pt\
                     --model-zoo-raw path/to/pretrained/ofa/models/\
                     --base-path path/for/storing/outcomes/\
                     --seed 777

You can also simply run a script file after updating the paths.

$ cd scripts
$ sh test.sh GPU_NO
Owner
Wonyong Jeong
Ph.D. Candidate @ KAIST AI
Wonyong Jeong
Code of our paper "Contrastive Object-level Pre-training with Spatial Noise Curriculum Learning"

CCOP Code of our paper Contrastive Object-level Pre-training with Spatial Noise Curriculum Learning Requirement Install OpenSelfSup Install Detectron2

Chenhongyi Yang 21 Dec 13, 2022
LERP : Label-dependent and event-guided interpretable disease risk prediction using EHRs

LERP : Label-dependent and event-guided interpretable disease risk prediction using EHRs This is the code for the LERP. Dataset The dataset used is MI

5 Jun 18, 2022
Recurrent Scale Approximation (RSA) for Object Detection

Recurrent Scale Approximation (RSA) for Object Detection Codebase for Recurrent Scale Approximation for Object Detection in CNN published at ICCV 2017

Yu Liu (Louis) 239 Dec 28, 2022
Vpw analyzer - A visual J1850 VPW analyzer written in Python

VPW Analyzer A visual J1850 VPW analyzer written in Python Requires Tkinter, Pan

7 May 01, 2022
Connecting Java/ImgLib2 + Python/NumPy

imglyb imglyb aims at connecting two worlds that have been seperated for too long: Python with numpy Java with ImgLib2 imglyb uses jpype to access num

ImgLib2 29 Dec 21, 2022
Codes for TS-CAM: Token Semantic Coupled Attention Map for Weakly Supervised Object Localization.

TS-CAM: Token Semantic Coupled Attention Map for Weakly SupervisedObject Localization This is the official implementaion of paper TS-CAM: Token Semant

vasgaowei 112 Jan 02, 2023
Image-based Navigation in Real-World Environments via Multiple Mid-level Representations: Fusion Models Benchmark and Efficient Evaluation

Image-based Navigation in Real-World Environments via Multiple Mid-level Representations: Fusion Models Benchmark and Efficient Evaluation This reposi

First Person Vision @ Image Processing Laboratory - University of Catania 1 Aug 21, 2022
[SDM 2022] Towards Similarity-Aware Time-Series Classification

SimTSC This is the PyTorch implementation of SDM2022 paper Towards Similarity-Aware Time-Series Classification. We propose Similarity-Aware Time-Serie

Daochen Zha 49 Dec 27, 2022
Code and data form the paper BERT Got a Date: Introducing Transformers to Temporal Tagging

BERT Got a Date: Introducing Transformers to Temporal Tagging Satya Almasian*, Dennis Aumiller*, and Michael Gertz Heidelberg University Contact us vi

54 Dec 04, 2022
Official PyTorch Implementation of "Self-supervised Auxiliary Learning with Meta-paths for Heterogeneous Graphs". NeurIPS 2020.

Self-supervised Auxiliary Learning with Meta-paths for Heterogeneous Graphs This repository is the implementation of SELAR. Dasol Hwang* , Jinyoung Pa

MLV Lab (Machine Learning and Vision Lab at Korea University) 48 Nov 09, 2022
Consensus Learning from Heterogeneous Objectives for One-Class Collaborative Filtering

Consensus Learning from Heterogeneous Objectives for One-Class Collaborative Filtering This repository provides the source code of "Consensus Learning

SeongKu-Kang 6 Apr 29, 2022
PyTorch implementation of our ICCV 2021 paper Intrinsic-Extrinsic Preserved GANs for Unsupervised 3D Pose Transfer.

Unsupervised_IEPGAN This is the PyTorch implementation of our ICCV 2021 paper Intrinsic-Extrinsic Preserved GANs for Unsupervised 3D Pose Transfer. Ha

25 Oct 26, 2022
Implementation of the paper: "SinGAN: Learning a Generative Model from a Single Natural Image"

SinGAN This is an unofficial implementation of SinGAN from someone who's been sitting right next to SinGAN's creator for almost five years. Please ref

35 Nov 10, 2022
Easy-to-use,Modular and Extendible package of deep-learning based CTR models .

DeepCTR DeepCTR is a Easy-to-use,Modular and Extendible package of deep-learning based CTR models along with lots of core components layers which can

浅梦 6.6k Jan 08, 2023
Weakly Supervised Segmentation with Tensorflow. Implements instance segmentation as described in Simple Does It: Weakly Supervised Instance and Semantic Segmentation, by Khoreva et al. (CVPR 2017).

Weakly Supervised Segmentation with TensorFlow This repo contains a TensorFlow implementation of weakly supervised instance segmentation as described

Phil Ferriere 220 Dec 13, 2022
Pytorch version of VidLanKD: Improving Language Understanding viaVideo-Distilled Knowledge Transfer

VidLanKD Implementation of VidLanKD: Improving Language Understanding via Video-Distilled Knowledge Transfer by Zineng Tang, Jaemin Cho, Hao Tan, Mohi

Zineng Tang 54 Dec 20, 2022
SnapMix: Semantically Proportional Mixing for Augmenting Fine-grained Data (AAAI 2021)

SnapMix: Semantically Proportional Mixing for Augmenting Fine-grained Data (AAAI 2021) PyTorch implementation of SnapMix | paper Method Overview Cite

DavidHuang 126 Dec 30, 2022
JDet is Object Detection Framework based on Jittor.

JDet is Object Detection Framework based on Jittor.

135 Dec 14, 2022
Analysis code and Latex source of the manuscript describing the conditional permutation test of confounding bias in predictive modelling.

Git repositoty of the manuscript entitled Statistical quantification of confounding bias in predictive modelling by Tamas Spisak The manuscript descri

PNI - Predictive Neuroimaging Lab, University Hospital Essen, Germany 0 Nov 22, 2021
DeepProbLog is an extension of ProbLog that integrates Probabilistic Logic Programming with deep learning by introducing the neural predicate.

DeepProbLog DeepProbLog is an extension of ProbLog that integrates Probabilistic Logic Programming with deep learning by introducing the neural predic

KU Leuven Machine Learning Research Group 94 Dec 18, 2022