[IJCAI-2021] A benchmark of data-free knowledge distillation from paper "Contrastive Model Inversion for Data-Free Knowledge Distillation"

Overview

DataFree

Authors: Gongfan Fang, Jie Song, Xinchao Wang, Chengchao Shen, Xingen Wang, Mingli Song

[Figure: samples from CMI (this work), DeepInv, ZSKT and DFQ]

Results

1. CIFAR-10

Teacher     resnet-34  vgg-11     wrn-40-2   wrn-40-2   wrn-40-2
Student     resnet-18  resnet-18  wrn-16-1   wrn-40-1   wrn-16-2
T. Scratch  95.70      92.25      94.87      94.87      94.87
S. Scratch  95.20      95.20      91.12      93.94      93.95
DAFL        92.22      81.10      65.71      81.33      81.55
ZSKT        93.32      89.46      83.74      86.07      89.66
DeepInv     93.26      90.36      83.04      86.85      89.72
DFQ         94.61      90.84      86.14      91.69      92.01
CMI         94.84      91.13      90.01      92.78      92.52

2. CIFAR-100

Teacher     resnet-34  vgg-11     wrn-40-2   wrn-40-2   wrn-40-2
Student     resnet-18  resnet-18  wrn-16-1   wrn-40-1   wrn-16-2
T. Scratch  78.05      71.32      75.83      75.83      75.83
S. Scratch  77.10      77.01      65.31      72.19      73.56
DAFL        74.47      57.29      22.50      34.66      40.00
ZSKT        67.74      34.72      30.15      29.73      28.44
DeepInv     61.32      54.13      53.77      61.33      61.34
DFQ         77.01      68.32      54.77      62.92      59.01
CMI         77.04      70.56      57.91      68.88      68.75

Quick Start

1. Visualize the inverted samples

Results will be saved as checkpoints/datafree-cmi/synthetic-cmi_for_vis.png

bash scripts/cmi/cmi_cifar10_for_vis.sh

2. Reproduce our results

Note: This repo was refactored from our experimental code and is still under development. I'm struggling to find appropriate hyper-parameters for every method (°ー°〃). So far, we only provide the hyper-parameters to reproduce the CIFAR-10 results for wrn-40-2 => wrn-16-1. You may need to tune the hyper-parameters for other models and datasets. More resources will be uploaded in future updates.

To reproduce our results, please download the pre-trained teacher models from Dropbox-Models (266 MB) and extract them to checkpoints/pretrained. A pre-inverted dataset with ~50k samples is also available for the wrn-40-2 teacher on CIFAR-10. You can download it from Dropbox-Data (133 MB) and extract it to run/cmi-preinverted-wrn402/.

  • Non-adversarial CMI: you can train a student model on the inverted data directly. It should reach an accuracy of ~87.38% on CIFAR-10, as reported in Figure 3.

    bash scripts/cmi/nonadv_cmi_cifar10_wrn402_wrn161.sh
    
  • Adversarial CMI: alternatively, you can apply adversarial distillation on top of the pre-inverted data, generating ~10k (256x40) new samples to improve the student. It should reach an accuracy of ~90.01% on CIFAR-10, as reported in Table 1.

    bash scripts/cmi/adv_cmi_cifar10_wrn402_wrn161.sh
    
  • Scratch CMI: it is OK to run the CMI algorithm without any pre-inverted data, but the student may overfit to early samples due to the limited amount of data. It should reach an accuracy of ~88.82% on CIFAR-10, slightly below our reported result (90.01%).

    bash scripts/cmi/scratch_cmi_cifar10_wrn402_wrn161.sh
    

3. Scratch training

python train_scratch.py --model wrn40_2 --dataset cifar10 --batch-size 256 --lr 0.1 --epoch 200 --gpu 0

4. Vanilla KD

# KD with original training data (beta>0 to use hard targets)
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar10 --beta 0.1 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar100 --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data from a specified folder
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set run/cmi --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 
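The --beta flag in the commands above decides whether hard labels contribute to the loss. As a rough, framework-free sketch of what such an objective commonly looks like, the function below combines a temperature-softened KL term between teacher and student with beta times the cross-entropy against the ground-truth label. The temperature of 4, the T² scaling and the function names are assumptions for illustration only; vanilla_kd.py may weight or scale the terms differently.

```python
import math
from typing import List, Optional, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def kd_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: Optional[int] = None,
    temperature: float = 4.0,  # assumed value, not read from vanilla_kd.py
    beta: float = 0.0,
) -> float:
    """Soft-target KD loss for one sample plus beta * hard-label cross-entropy."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    # KL(teacher || student) on softened distributions, scaled by T^2 (a common convention)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    loss = kl * temperature ** 2
    if beta > 0:
        # Hard targets need a labeled transfer set (e.g. --transfer_set cifar10)
        if label is None:
            raise ValueError("beta > 0 requires a ground-truth label")
        ce = -log_softmax(student_logits)[label]
        loss += beta * ce
    return loss


# Unlabeled transfer set (beta = 0): only the soft-target term is used.
print(kd_loss([2.0, 0.5, -1.0], [1.5, 1.0, -0.5]))
# Labeled transfer set (beta > 0): the hard-label cross-entropy is added.
print(kd_loss([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=0, beta=0.1))
```

With beta = 0 the student only imitates the teacher's soft predictions, which is why that setting is the one used for unlabeled transfer sets such as cifar100 or a folder of synthesized images.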

5. Data-free KD

bash scripts/xxx/xxx.sh # e.g. scripts/zskt/zskt_cifar10_wrn402_wrn161.sh

Hyper-parameters used by different methods:

Method adv bn oh balance act cr GAN Example
DAFL - - - scripts/dafl_cifar10.sh
ZSKT - - - - - scripts/zskt_cifar10.sh
DeepInv - - - - scripts/deepinv_cifar10.sh
DFQ - - scripts/dfq_cifar10.sh
CMI - - scripts/cmi_cifar10_scratch.sh

6. Use your models/datasets

You can register your models and datasets in registry.py by modifying NORMALIZE_DICT, MODEL_DICT and get_dataset. Then you can run the above commands to train your own models. Because DAFL requires intermediate features from the penultimate layer, your model should accept a return_features=True argument and, when it is set, return a (logits, features) tuple.
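A minimal sketch of the return_features convention, written without PyTorch so it runs anywhere: a toy two-layer model stands in for a real network. ToyMLP, get_model, the "toy_mlp" key and the "my_dataset" normalization statistics are all hypothetical and only mimic the shape of entries one might add to registry.py; the real dictionaries hold PyTorch models and real dataset statistics.

```python
import random
from typing import Callable, Dict, List, Tuple, Union

Vector = List[float]


class ToyMLP:
    """Hypothetical two-layer model standing in for a real torch.nn.Module."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int, seed: int = 0):
        rng = random.Random(seed)
        self.w1 = [[rng.gauss(0, 0.1) for _ in range(in_dim)] for _ in range(hidden)]
        self.w2 = [[rng.gauss(0, 0.1) for _ in range(hidden)] for _ in range(num_classes)]

    def __call__(self, x: Vector, return_features: bool = False) -> Union[Vector, Tuple[Vector, Vector]]:
        # Penultimate-layer features (what DAFL needs) ...
        features = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in self.w1]
        # ... followed by the final linear classifier.
        logits = [sum(w * f for w, f in zip(row, features)) for row in self.w2]
        if return_features:
            return logits, features
        return logits


# Hypothetical registry entries; the real keys and values in registry.py differ.
MODEL_DICT: Dict[str, Callable[..., ToyMLP]] = {
    "toy_mlp": lambda num_classes: ToyMLP(in_dim=4, hidden=8, num_classes=num_classes),
}
NORMALIZE_DICT = {
    "my_dataset": dict(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),  # placeholder stats
}


def get_model(name: str, num_classes: int) -> ToyMLP:
    """Hypothetical lookup helper: build a registered model by name."""
    return MODEL_DICT[name](num_classes=num_classes)


model = get_model("toy_mlp", num_classes=10)
logits, feats = model([0.1, -0.2, 0.3, 0.4], return_features=True)
print(len(logits), len(feats))  # 10 8
```

Returning plain logits by default keeps the model usable by every other method, while the flag exposes the extra features only when a method such as DAFL asks for them.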

7. Implement your algorithms

Your algorithm should inherit from datafree.synthesis.BaseSynthesizer and implement two interfaces: 1) BaseSynthesizer.synthesize takes several steps to craft new samples and returns an image dict for visualization; 2) BaseSynthesizer.sample fetches a batch of training data for KD.
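The sketch below illustrates this two-method contract with a stand-in abstract class; the real datafree.synthesis.BaseSynthesizer may take constructor arguments and helpers that are not reproduced here. NoiseSynthesizer is hypothetical: it keeps every synthesized batch in an in-memory list (a memory-bank design assumed for this sketch) so that sample() can draw KD batches from everything generated so far, and its inner loop is only a placeholder for where a real method would optimize the images through the teacher.

```python
import random
from abc import ABC, abstractmethod
from typing import Dict, List

Image = List[float]


class BaseSynthesizer(ABC):
    """Stand-in for datafree.synthesis.BaseSynthesizer with the two described interfaces."""

    @abstractmethod
    def synthesize(self) -> Dict[str, List[Image]]:
        """Run several steps to craft new samples; return images for visualization."""

    @abstractmethod
    def sample(self) -> List[Image]:
        """Return one batch of training data for KD."""


class NoiseSynthesizer(BaseSynthesizer):
    """Hypothetical toy synthesizer: 'images' are random vectors kept in a memory bank."""

    def __init__(self, img_dim=4, synthesis_batch_size=8, sample_batch_size=4, iterations=3, seed=0):
        self.rng = random.Random(seed)
        self.img_dim = img_dim
        self.synthesis_batch_size = synthesis_batch_size
        self.sample_batch_size = sample_batch_size
        self.iterations = iterations
        self.memory: List[Image] = []

    def synthesize(self) -> Dict[str, List[Image]]:
        batch = [[self.rng.uniform(-1.0, 1.0) for _ in range(self.img_dim)]
                 for _ in range(self.synthesis_batch_size)]
        for _ in range(self.iterations):
            # Placeholder step: a real method would update `batch` using losses computed
            # through the teacher; here the values are simply shrunk toward zero.
            batch = [[0.9 * v for v in img] for img in batch]
        self.memory.extend(batch)  # keep new samples for later KD steps
        return {"synthetic": batch}

    def sample(self) -> List[Image]:
        k = min(self.sample_batch_size, len(self.memory))
        return self.rng.sample(self.memory, k=k)


syn = NoiseSynthesizer()
for _ in range(3):
    images = syn.synthesize()  # grow the pool and get images to visualize
    for _ in range(2):
        batch = syn.sample()  # a KD step would train the student on this batch
print(len(syn.memory))  # 24
```

Keeping synthesis and sampling separate lets the training loop decide how many KD steps to run per synthesis round without the synthesizer knowing anything about the student.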

Citation

If you found this work useful for your research, please cite our paper:

@misc{fang2021contrastive,
      title={Contrastive Model Inversion for Data-Free Knowledge Distillation}, 
      author={Gongfan Fang and Jie Song and Xinchao Wang and Chengchao Shen and Xingen Wang and Mingli Song},
      year={2021},
      eprint={2105.08584},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}

Owner: ZJU-VIPA (Laboratory of Visual Intelligence and Pattern Analysis)