AlphaNet: Improved Training of Supernet with Alpha-Divergence

This repository contains our PyTorch training code, evaluation code and pretrained models for AlphaNet.

Our implementation is largely based on AttentiveNAS. To reproduce our results, please first download the AttentiveNAS repository, then use our train_alphanet.py for training and test_alphanet.py for testing.

For more details, please see our paper "AlphaNet: Improved Training of Supernet with Alpha-Divergence" by Dilin Wang, Chengyue Gong, Meng Li, Qiang Liu, and Vikas Chandra.

If you find this repo useful in your research, please consider citing our work and AttentiveNAS:

@article{wang2021alphanet,
  title={AlphaNet: Improved Training of Supernet with Alpha-Divergence},
  author={Wang, Dilin and Gong, Chengyue and Li, Meng and Liu, Qiang and Chandra, Vikas},
  journal={arXiv preprint arXiv:2102.07954},
  year={2021}
}

@article{wang2020attentivenas,
  title={AttentiveNAS: Improving Neural Architecture Search via Attentive Sampling},
  author={Wang, Dilin and Li, Meng and Gong, Chengyue and Chandra, Vikas},
  journal={arXiv preprint arXiv:2011.09011},
  year={2020}
}

Evaluation

To reproduce our results:

  • Please first download our pretrained AlphaNet models from Google Drive and put them in your local folder ./alphanet_data.

  • To evaluate a pretrained AlphaNet model (AlphaNet-A0 to A6) on ImageNet with a single GPU, run the command below, replacing a[0-6] with the desired model, e.g. a0:

    python test_alphanet.py --config-file ./configs/eval_alphanet_models.yml --model a[0-6]

    Expected results:

    Model                  MFLOPs   Top-1 (%)
    AlphaNet-A0            203      77.87
    AlphaNet-A1            279      78.94
    AlphaNet-A2            317      79.20
    AlphaNet-A3            357      79.41
    AlphaNet-A4            444      80.01
    AlphaNet-A5 (small)    491      80.29
    AlphaNet-A5 (base)     596      80.62
    AlphaNet-A6            709      80.78
  • Additionally, we provide our pretrained supernet trained with KL-based in-place knowledge distillation (inplace-KD), as well as our pretrained supernet trained without inplace-KD.
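The two supernet checkpoints differ in whether in-place knowledge distillation (inplace-KD) was used. As a rough, framework-free illustration (not code from this repository), the sketch below computes the commonly used KL-based inplace-KD loss, in which the largest subnet's softmax output p is the soft target for a sampled smaller subnet's output q, together with the standard (Amari) alpha-divergence that generalizes KL. AlphaNet's actual objective lives in train_alphanet.py and the paper and may differ from this sketch, for example in which argument plays the teacher role and how alpha is chosen. All function names are hypothetical; a real implementation would work on batched PyTorch tensors and detach the teacher output.

```python
import math
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax of one logit vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def alpha_divergence(p: Sequence[float], q: Sequence[float], alpha: float) -> float:
    """Amari alpha-divergence D_alpha(p || q) for alpha not in {0, 1}.

    D_alpha(p || q) = sum_i q_i * ((p_i / q_i) ** alpha - 1) / (alpha * (alpha - 1))
    It approaches KL(p || q) as alpha -> 1 and KL(q || p) as alpha -> 0.
    Assumes every q_i > 0, which holds for softmax outputs.
    """
    if alpha in (0.0, 1.0):
        raise ValueError("alpha must not be 0 or 1; use kl_divergence for those limits")
    s = sum(qi * ((pi / qi) ** alpha - 1.0) for pi, qi in zip(p, q))
    return s / (alpha * (alpha - 1.0))


def inplace_kd_loss(teacher_logits: Sequence[float],
                    student_logits: Sequence[float],
                    alpha: Optional[float] = None) -> float:
    """Distillation loss for a single example (hypothetical helper).

    p (largest subnet) is treated as a fixed soft target; q (sampled smaller
    subnet) is what training would push towards p. alpha=None gives KL(p || q).
    """
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    if alpha is None:
        return kl_divergence(p, q)
    return alpha_divergence(p, q, alpha)


if __name__ == "__main__":
    teacher = [2.0, 1.0, 0.1]  # made-up logits for illustration
    student = [0.5, 0.3, 0.2]
    print("KL inplace-KD:", inplace_kd_loss(teacher, student))
    print("alpha = 0.5  :", inplace_kd_loss(teacher, student, alpha=0.5))
    print("alpha = 1.5  :", inplace_kd_loss(teacher, student, alpha=1.5))
```

Setting alpha close to 1 recovers the KL loss, which gives a quick sanity check that the generalization is wired up correctly.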

Training

To train our AlphaNet models from scratch, please run:

python train_alphanet.py --config-file configs/train_alphanet_models.yml --machine-rank ${machine_rank} --num-machines ${num_machines} --dist-url ${dist_url}

We train with SGD on 64 GPUs using a mini-batch size of 32 per GPU (a total batch size of 2,048). All training hyperparameters are specified in configs/train_alphanet_models.yml.
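To make the multi-machine launch concrete, here is a small hypothetical Python helper (not part of the repository) that reproduces the batch-size arithmetic and prints one train_alphanet.py command per machine. It assumes the 64 GPUs are split as 8 machines with 8 GPUs each, and that --dist-url takes a torch.distributed-style address such as tcp://host:port shared by all machines; both the split and the address below are made up for illustration.

```python
import shlex
from typing import List

# Figures stated in the README: SGD on 64 GPUs, mini-batch of 32 per GPU.
NUM_GPUS = 64
BATCH_PER_GPU = 32


def global_batch_size(num_gpus: int = NUM_GPUS, per_gpu: int = BATCH_PER_GPU) -> int:
    """Total number of images consumed by one synchronous SGD step."""
    return num_gpus * per_gpu


def launch_commands(num_machines: int, dist_url: str,
                    config_file: str = "configs/train_alphanet_models.yml") -> List[str]:
    """Build the train_alphanet.py command line to run on each machine.

    Machine i gets --machine-rank i; all machines share the same
    --num-machines and --dist-url (assumed to be the common rendezvous address).
    """
    commands = []
    for rank in range(num_machines):
        argv = [
            "python", "train_alphanet.py",
            "--config-file", config_file,
            "--machine-rank", str(rank),
            "--num-machines", str(num_machines),
            "--dist-url", dist_url,
        ]
        commands.append(shlex.join(argv))  # shlex.join requires Python 3.8+
    return commands


if __name__ == "__main__":
    print("global batch size:", global_batch_size())  # 64 * 32
    # Hypothetical cluster: 8 machines x 8 GPUs, rendezvous on a made-up host.
    for cmd in launch_commands(8, "tcp://10.0.0.1:23456"):
        print(cmd)
```

Each printed line would be run on the machine whose rank it names; only --machine-rank changes between machines.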

Evolutionary search

If you want to search for models of your own interest, parallel_supernet_evo_search.py provides an example of searching for the Pareto-optimal models with the best FLOPs vs. accuracy trade-offs. To run this example:

python parallel_supernet_evo_search.py --config-file configs/parallel_supernet_evo_search.yml 
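parallel_supernet_evo_search.py performs the actual evolutionary search; the short sketch below only illustrates the Pareto-filtering goal described above, i.e. keeping the models for which no other model is both cheaper and more accurate. It uses the numbers from the Evaluation table plus one made-up candidate, and the function name pareto_front is hypothetical rather than taken from the script.

```python
from typing import List, Tuple

# (name, MFLOPs, ImageNet top-1 %) as listed in the Evaluation table.
Model = Tuple[str, int, float]

ALPHANET_MODELS: List[Model] = [
    ("AlphaNet-A0", 203, 77.87),
    ("AlphaNet-A1", 279, 78.94),
    ("AlphaNet-A2", 317, 79.20),
    ("AlphaNet-A3", 357, 79.41),
    ("AlphaNet-A4", 444, 80.01),
    ("AlphaNet-A5 (small)", 491, 80.29),
    ("AlphaNet-A5 (base)", 596, 80.62),
    ("AlphaNet-A6", 709, 80.78),
]


def pareto_front(models: List[Model]) -> List[Model]:
    """Return the models that no other model dominates.

    Sorting by FLOPs (ties: higher accuracy first) means a model is on the
    front exactly when it beats the best accuracy seen among cheaper models.
    """
    front: List[Model] = []
    best_acc = float("-inf")
    for name, flops, acc in sorted(models, key=lambda m: (m[1], -m[2])):
        if acc > best_acc:
            front.append((name, flops, acc))
            best_acc = acc
    return front


if __name__ == "__main__":
    # "candidate-X" is a made-up architecture: it costs more FLOPs than
    # AlphaNet-A3 but is less accurate, so it is dominated and filtered out.
    pool = ALPHANET_MODELS + [("candidate-X", 400, 79.0)]
    for name, flops, acc in pareto_front(pool):
        print(f"{name:22s} {flops:4d} MFLOPs  {acc:.2f}%")
```

All eight released models survive the filter, since accuracy rises strictly with FLOPs across the table.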

License

AlphaNet is licensed under CC-BY-NC.

Contributing

We actively welcome your pull requests! Please see CONTRIBUTING and CODE_OF_CONDUCT for more info.

Owner
Facebook Research