PyTorch implementation of "Continual Learning with Deep Generative Replay", NIPS 2017

Overview

pytorch-deep-generative-replay



Results

Continual Learning on Permuted MNIST

  • Test precisions without replay (left), with exact replay (middle), and with Deep Generative Replay (right).
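Permuted-MNIST tasks are commonly built by drawing one fixed random pixel permutation per task and applying it to every MNIST image, so all tasks share the same pixels in a different order. The plain-Python sketch below illustrates that idea. It assumes that `--mnist-permutation-number` sets how many permutations (tasks) are drawn and that `--mnist-permutation-seed` seeds the draw. `make_permutations` and `permute_image` are hypothetical helpers, not functions from this repository.

```python
import random

IMAGE_SIZE = 28 * 28  # MNIST images flattened to 784 pixels


def make_permutations(number, seed):
    """Draw `number` fixed pixel permutations, one per task (hypothetical helper)."""
    rng = random.Random(seed)  # seeded so every run produces the same tasks
    permutations = []
    for _ in range(number):
        perm = list(range(IMAGE_SIZE))
        rng.shuffle(perm)
        permutations.append(perm)
    return permutations


def permute_image(pixels, perm):
    """Reorder a flattened image: output pixel i is taken from input position perm[i]."""
    return [pixels[p] for p in perm]


if __name__ == "__main__":
    perms = make_permutations(number=3, seed=0)
    image = list(range(IMAGE_SIZE))  # stand-in for a flattened MNIST image
    task_views = [permute_image(image, perm) for perm in perms]
    # Every task sees exactly the same pixel values, only in a different order.
    print(all(sorted(view) == image for view in task_views))  # True
```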

Continual Learning on MNIST-SVHN

  • Test precisions without replay (left), with exact replay (middle), and with Deep Generative Replay (right).

  • Generated samples from the scholar trained without any replay (left) and with Deep Generative Replay (right).

  • Training the scholar's generator without replay (left) and with Deep Generative Replay (right).
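The panels above compare replay modes that differ only in what the model sees from earlier tasks. A minimal sketch of how the data for one solver update could be assembled under each `--replay-mode` value follows. `build_batch`, the `(x, y)` list format, and the `old_generator`/`old_solver` callables (standing in for the previous scholar) are hypothetical, not this project's code.

```python
import random


def build_batch(mode, current_data, batch_size, rng,
                previous_data=None, old_generator=None, old_solver=None):
    """Return (new_batch, replay_batch) for one solver update (hypothetical helper).

    mode mirrors the --replay-mode choices: "none", "exact-replay", "generative-replay".
    current_data / previous_data are lists of (x, y) pairs; old_generator() returns one
    generated input and old_solver(x) returns the previous solver's label for it.
    """
    new_batch = rng.sample(current_data, batch_size)
    if mode == "none":
        replay_batch = []  # earlier tasks are never revisited
    elif mode == "exact-replay":
        replay_batch = rng.sample(previous_data, batch_size)  # stored real examples
    elif mode == "generative-replay":
        xs = [old_generator() for _ in range(batch_size)]
        # The previous scholar labels its own samples; no old data is stored.
        replay_batch = [(x, old_solver(x)) for x in xs]
    else:
        raise ValueError(f"unknown replay mode: {mode}")
    return new_batch, replay_batch


if __name__ == "__main__":
    rng = random.Random(0)
    current = [(i, "task-2") for i in range(10)]
    new_batch, replay_batch = build_batch(
        "generative-replay", current, batch_size=2, rng=rng,
        old_generator=rng.random, old_solver=lambda x: "task-1",
    )
    print(len(new_batch), [y for _, y in replay_batch])  # 2 ['task-1', 'task-1']
```

On the first task there is no previous scholar and no stored data, so every mode reduces to `none`.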

Continual Learning on SVHN-MNIST

  • Test precisions without replay (left), with exact replay (middle), and with Deep Generative Replay (right).

  • Generated samples from the scholar trained without replay (left) and with Deep Generative Replay (right).

  • Training the scholar's generator without replay (left) and with Deep Generative Replay (right).
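The `--generator-c-updates-per-g-update` flag, together with the separate `c`/`g` channel-size flags, suggests the generator is trained as a critic/generator pair in the style of WGAN, with several critic steps per generator step. That reading is an assumption. The sketch below shows only such an alternating schedule, using hypothetical `update_critic`/`update_generator` callables. Under generative replay, the critic's "real" batches would mix current-task data with samples from the previous generator.

```python
def train_generator(iterations, c_updates_per_g_update, update_critic, update_generator):
    """Alternate critic and generator steps; return (critic_steps, generator_steps)."""
    critic_steps = 0
    generator_steps = 0
    for _ in range(iterations):
        # Several critic steps first, so the critic tracks the current generator.
        for _ in range(c_updates_per_g_update):
            update_critic()
            critic_steps += 1
        # Then a single generator step against the freshly updated critic.
        update_generator()
        generator_steps += 1
    return critic_steps, generator_steps


if __name__ == "__main__":
    log = []
    counts = train_generator(2, 3, lambda: log.append("c"), lambda: log.append("g"))
    print(counts, "".join(log))  # (6, 2) cccgcccg
```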

Installation

$ git clone https://github.com/kuc2477/pytorch-deep-generative-replay
$ pip install -r pytorch-deep-generative-replay/requirements.txt

Commands

Usage

$ ./main.py --help
usage: PyTorch implementation of Deep Generative Replay [-h]
                                                        [--experiment {permutated-mnist,svhn-mnist,mnist-svhn}]
                                                        [--mnist-permutation-number MNIST_PERMUTATION_NUMBER]
                                                        [--mnist-permutation-seed MNIST_PERMUTATION_SEED]
                                                        --replay-mode
                                                        {exact-replay,generative-replay,none}
                                                        [--generator-z-size GENERATOR_Z_SIZE]
                                                        [--generator-c-channel-size GENERATOR_C_CHANNEL_SIZE]
                                                        [--generator-g-channel-size GENERATOR_G_CHANNEL_SIZE]
                                                        [--solver-depth SOLVER_DEPTH]
                                                        [--solver-reducing-layers SOLVER_REDUCING_LAYERS]
                                                        [--solver-channel-size SOLVER_CHANNEL_SIZE]
                                                        [--generator-c-updates-per-g-update GENERATOR_C_UPDATES_PER_G_UPDATE]
                                                        [--generator-iterations GENERATOR_ITERATIONS]
                                                        [--solver-iterations SOLVER_ITERATIONS]
                                                        [--importance-of-new-task IMPORTANCE_OF_NEW_TASK]
                                                        [--lr LR]
                                                        [--weight-decay WEIGHT_DECAY]
                                                        [--batch-size BATCH_SIZE]
                                                        [--test-size TEST_SIZE]
                                                        [--sample-size SAMPLE_SIZE]
                                                        [--image-log-interval IMAGE_LOG_INTERVAL]
                                                        [--eval-log-interval EVAL_LOG_INTERVAL]
                                                        [--loss-log-interval LOSS_LOG_INTERVAL]
                                                        [--checkpoint-dir CHECKPOINT_DIR]
                                                        [--sample-dir SAMPLE_DIR]
                                                        [--no-gpus]
                                                        (--train | --test)
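In the Deep Generative Replay paper, the solver loss on a new task mixes the loss on real new-task data with the loss on replayed data through a ratio r. Assuming `--importance-of-new-task` plays the role of r, the weighting can be sketched as below. `mixed_loss` is a hypothetical helper that works on already-computed per-example loss values.

```python
from statistics import fmean


def mixed_loss(new_task_losses, replay_losses, importance_of_new_task):
    """Weight new-task and replayed losses by r and 1 - r (r = importance of new task).

    new_task_losses: per-example losses on real data from the current task.
    replay_losses: per-example losses on replayed inputs; under generative replay
    their targets are the previous solver's predictions.
    """
    r = importance_of_new_task
    if not 0.0 <= r <= 1.0:
        raise ValueError("importance_of_new_task must lie in [0, 1]")
    if not replay_losses:
        # First task, or --replay-mode=none: nothing to mix in.
        return fmean(new_task_losses)
    return r * fmean(new_task_losses) + (1 - r) * fmean(replay_losses)


if __name__ == "__main__":
    print(mixed_loss([1.0, 3.0], [4.0], importance_of_new_task=0.5))  # 3.0
```

A larger r favors plasticity on the current task; a smaller r puts more weight on retaining what the previous scholar knew.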

To Run Full Experiments

# Run a visdom server and conduct full experiments
$ python -m visdom.server &
$ ./run_full_experiments

To Run a Single Experiment

# Run a visdom server and conduct a desired experiment
$ python -m visdom.server &
$ ./main.py --train --experiment=[permutated-mnist|svhn-mnist|mnist-svhn] --replay-mode=[exact-replay|generative-replay|none]
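To sweep every combination of the documented `--experiment` and `--replay-mode` values yourself, the command lines can be generated programmatically. This is a dry-run sketch that only uses flags shown in the usage text above; `build_commands` is a hypothetical helper and is not necessarily what `run_full_experiments` does.

```python
import itertools

EXPERIMENTS = ["permutated-mnist", "svhn-mnist", "mnist-svhn"]
REPLAY_MODES = ["exact-replay", "generative-replay", "none"]


def build_commands(mode="--train"):
    """Return one argv list per (experiment, replay mode) pair."""
    return [
        ["./main.py", mode, f"--experiment={exp}", f"--replay-mode={replay}"]
        for exp, replay in itertools.product(EXPERIMENTS, REPLAY_MODES)
    ]


if __name__ == "__main__":
    for cmd in build_commands():
        print(" ".join(cmd))
        # To actually launch a run: subprocess.run(cmd, check=True)
```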

To Generate Images from the Learned Scholar

# Run the command below and check the "samples" directory for the generated images
$ ./main.py --test --experiment=[permutated-mnist|svhn-mnist|mnist-svhn] --replay-mode=[exact-replay|generative-replay|none]

Note

  • I could not find the supplementary materials that the authors mention in the paper as containing the experimental details, so I arbitrarily chose a CNN with 4 convolutional layers as the solver in this project. If you know where to find these materials, please let me know through the project's GitHub issues.
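Because the solver architecture is an arbitrary choice here, it can help to check the spatial size of the feature map it produces before the classification layer. The sketch applies the standard convolution output-size formula, floor((n + 2p - k) / s) + 1. It assumes 3x3 kernels with padding 1 and that the last `--solver-reducing-layers` of the `--solver-depth` layers use stride 2 while the rest use stride 1; these details are assumptions, not read from the project's code.

```python
def conv_output_size(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def solver_feature_size(image_size, depth=4, reducing_layers=2):
    """Feature-map size after `depth` conv layers; the last `reducing_layers` use stride 2."""
    if not 0 <= reducing_layers <= depth:
        raise ValueError("reducing_layers must be between 0 and depth")
    size = image_size
    for layer in range(depth):
        stride = 2 if layer >= depth - reducing_layers else 1
        size = conv_output_size(size, stride=stride)
    return size


if __name__ == "__main__":
    print(solver_feature_size(28))  # 7 (28 -> 28 -> 28 -> 14 -> 7)
    print(solver_feature_size(32))  # 8 (32 -> 32 -> 32 -> 16 -> 8)
```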

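Reference

  • Hanul Shin, Jung Kwon Lee, Jaehong Kim, Jiwon Kim. Continual Learning with Deep Generative Replay. NIPS 2017.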

Author

Ha Junsoo / @kuc2477 / MIT License
