Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

Overview

Dataset Distillation by Matching Training Trajectories

Project Page | Paper


Teaser image

This repo contains code for training expert trajectories and distilling synthetic data from our Dataset Distillation by Matching Training Trajectories paper (CVPR 2022). Please see our project page for more results.

Dataset Distillation by Matching Training Trajectories
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022

The task of "Dataset Distillation" is to learn a small number of synthetic images such that a model trained on this set alone will have similar test performance as a model trained on the full real dataset.

Our method distills the synthetic dataset by directly optimizing the fake images to induce similar network training dynamics as the full, real dataset. We train "student" networks for many iterations on the synthetic data, measure the error in parameter space between the "student" and "expert" networks trained on real data, and back-propagate through all the student network updates to optimize the synthetic pixels.

Wearable ImageNet: Synthesizing Tileable Textures

Teaser image

Instead of treating our synthetic data as individual images, we can instead encourage every random crop (with circular padding) on a larger canvas of pixels to induce a good training trajectory. This results in class-based textures that are continuous around their edges.

Given these tileable textures, we can apply them to areas that require such properties, such as clothing patterns.

Visualizations made using FAB3D

Getting Started

First, download our repo:

git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation

For an express instillation, we include .yaml files.

If you have an RTX 30XX GPU (or newer), run

conda env create -f requirements_11_3.yaml

If you have an RTX 20XX GPU (or older), run

conda env create -f requirements_10_2.yaml

You can then activate your conda environment with

conda activate distillation
Quadro Users Take Note:

torch.nn.DataParallel seems to not work on Quadro A5000 GPUs, and this may extend to other Quadro cards.

If you experience indefinite hanging during training, try running the process with only 1 GPU by prepending CUDA_VISIBLE_DEVICES=0 to the command.

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py

The following command will train 100 ConvNet models on CIFAR-100 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

We used 50 epochs with the default learning rate for all of our experts. Worse (but still interesting) results can be obtained faster through training fewer experts by changing --num_experts. Note that experts need only be trained once and can be re-used for multiple distillation experiments.

Distillation by Matching Training Trajectories

The following command will then use the buffers we just generated to distill CIFAR-100 down to just 1 image per class:

python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

ImageNet

Our method can also distill subsets of ImageNet into low-support synthetic sets.

When generating expert trajectories with buffer.py or distilling the dataset with distill.py, you must designate a named subset of ImageNet with the --subset flag.

For example,

python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagefruit subset (at 128x128 resolution) into the following 10 images

To register your own ImageNet subset, you can add it to the Config class at the top of utils.py.

Simply create a list with the desired class ID's and add it to the dictionary.

This gist contains a list of all 1k ImageNet classes and their corresponding numbers.

Texture Distillation

You can also use the same set of expert trajectories (except those using ZCA) to distill classes into toroidal textures by simply adding the --texture flag.

For example,

python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagesquawk subset (at 256x256 resolution) into the following 10 textures

Acknowledgments

We would like to thank Alexander Li, Assaf Shocher, Gokul Swamy, Kangle Deng, Ruihan Gao, Nupur Kumari, Muyang Li, Gaurav Parmar, Chonghyuk Song, Sheng-Yu Wang, and Bingliang Zhang as well as Simon Lucey's Vision Group at the University of Adelaide for their valuable feedback. This work is supported, in part, by the NSF Graduate Research Fellowship under Grant No. DGE1745016 and grants from J.P. Morgan Chase, IBM, and SAP. Our code is adapted from https://github.com/VICO-UoE/DatasetCondensation

Related Work

  1. Tongzhou Wang et al. "Dataset Distillation", in arXiv preprint 2018
  2. Bo Zhao et al. "Dataset Condensation with Gradient Matching", in ICLR 2020
  3. Bo Zhao and Hakan Bilen. "Dataset Condensation with Differentiable Siamese Augmentation", in ICML 2021
  4. Timothy Nguyen et al. "Dataset Meta-Learning from Kernel Ridge-Regression", in ICLR 2021
  5. Timothy Nguyen et al. "Dataset Distillation with Infinitely Wide Convolutional Networks", in NeurIPS 2021
  6. Bo Zhao and Hakan Bilen. "Dataset Condensation with Distribution Matching", in arXiv preprint 2021
  7. Kai Wang et al. "CAFE: Learning to Condense Dataset by Aligning Features", in CVPR 2022

Reference

If you find our code useful for your research, please cite our paper.

@inproceedings{
cazenavette2022distillation,
title={Dataset Distillation by Matching Training Trajectories},
author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
Owner
George Cazenavette
Carnegie Mellon University
George Cazenavette
(CVPR 2021) Back-tracing Representative Points for Voting-based 3D Object Detection in Point Clouds

BRNet Introduction This is a release of the code of our paper Back-tracing Representative Points for Voting-based 3D Object Detection in Point Clouds,

86 Oct 05, 2022
PyTorch/TorchScript compiler for NVIDIA GPUs using TensorRT

PyTorch/TorchScript compiler for NVIDIA GPUs using TensorRT

NVIDIA Corporation 1.8k Dec 30, 2022
This is a Python wrapper for TA-LIB based on Cython instead of SWIG.

TA-Lib This is a Python wrapper for TA-LIB based on Cython instead of SWIG. From the homepage: TA-Lib is widely used by trading software developers re

John Benediktsson 7.3k Jan 03, 2023
A collection of Jupyter notebooks to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

StyleGAN3 CLIP-based guidance StyleGAN3 + CLIP StyleGAN3 + inversion + CLIP This repo is a collection of Jupyter notebooks made to easily play with St

Eugenio Herrera 176 Dec 30, 2022
A user-friendly research and development tool built to standardize RL competency assessment for custom agents and environments.

Built with ❤️ by Sam Showalter Contents Overview Installation Dependencies Usage Scripts Standard Execution Environment Development Environment Benchm

SRI-AIC 1 Nov 18, 2021
Adaptive Prototype Learning and Allocation for Few-Shot Segmentation (CVPR 2021)

ASGNet The code is for the paper "Adaptive Prototype Learning and Allocation for Few-Shot Segmentation" (accepted to CVPR 2021) [arxiv] Overview data/

Gen Li 91 Dec 23, 2022
AdelaiDepth is an open source toolbox for monocular depth prediction.

AdelaiDepth is an open source toolbox for monocular depth prediction.

Adelaide Intelligent Machines (AIM) Group 743 Jan 01, 2023
Source-to-Source Debuggable Derivatives in Pure Python

Tangent Tangent is a new, free, and open-source Python library for automatic differentiation. Existing libraries implement automatic differentiation b

Google 2.2k Jan 01, 2023
This repository accompanies the ACM TOIS paper "What can I cook with these ingredients?" - Understanding cooking-related information needs in conversational search

In this repository you find data that has been gathered when conducting in-situ experiments in a conversational cooking setting. These data include tr

6 Sep 22, 2022
Libraries, tools and tasks created and used at DeepMind Robotics.

dm_robotics: Libraries, tools, and tasks created and used for Robotics research at DeepMind. Package overview Package Summary Transformations Rigid bo

DeepMind 273 Jan 06, 2023
Unsupervised Video Interpolation using Cycle Consistency

Unsupervised Video Interpolation using Cycle Consistency Project | Paper | YouTube Unsupervised Video Interpolation using Cycle Consistency Fitsum A.

NVIDIA Corporation 100 Nov 30, 2022
Axel - 3D printed robotic hands and they controll with Raspberry Pi and Arduino combo

Axel It's our graduation project about 3D printed robotic hands and they control

0 Feb 14, 2022
Benchmark for the generalization of 3D machine learning models across different remeshing/samplings of a surface.

Discretization Robust Correspondence Benchmark One challenge of machine learning on 3D surfaces is that there are many different representations/sampl

Nicholas Sharp 10 Sep 30, 2022
50-days-of-Statistics-for-Data-Science - This repository consist of a 50-day program

50-days-of-Statistics-for-Data-Science - This repository consist of a 50-day program. All the statistics required for the complete understanding of data science will be uploaded in this repository.

komal_lamba 22 Dec 09, 2022
[ICLR'19] Trellis Networks for Sequence Modeling

TrellisNet for Sequence Modeling This repository contains the experiments done in paper Trellis Networks for Sequence Modeling by Shaojie Bai, J. Zico

CMU Locus Lab 460 Oct 13, 2022
🏎️ Accelerate training and inference of 🤗 Transformers with easy to use hardware optimization tools

Hugging Face Optimum 🤗 Optimum is an extension of 🤗 Transformers, providing a set of performance optimization tools enabling maximum efficiency to t

Hugging Face 842 Dec 30, 2022
URIE: Universal Image Enhancementfor Visual Recognition in the Wild

URIE: Universal Image Enhancementfor Visual Recognition in the Wild This is the implementation of the paper "URIE: Universal Image Enhancement for Vis

Taeyoung Son 43 Sep 12, 2022
Refactoring dalle-pytorch and taming-transformers for TPU VM

Text-to-Image Translation (DALL-E) for TPU in Pytorch Refactoring Taming Transformers and DALLE-pytorch for TPU VM with Pytorch Lightning Requirements

Kim, Taehoon 61 Nov 07, 2022
Ankou: Guiding Grey-box Fuzzing towards Combinatorial Difference

Ankou Ankou is a source-based grey-box fuzzer. It intends to use a more rich fitness function by going beyond simple branch coverage and considering t

SoftSec Lab 54 Dec 24, 2022
A set of tools for converting a darknet dataset to COCO format working with YOLOX

darknet格式数据→COCO darknet训练数据目录结构(详情参见dataset/darknet): darknet ├── class.names ├── gen_config.data ├── gen_train.txt ├── gen_valid.txt └── images

RapidAI-NG 148 Jan 03, 2023