CxGrad - Official PyTorch Implementation
Contextual Gradient Scaling for Few-Shot Learning
Sanghyuk Lee, Seunghyun Lee, and Byung Cheol Song
In WACV 2022. (Paper link will be provided soon)
This repository is the official PyTorch implementation of "Contextual Gradient Scaling for Few-Shot Learning" (WACV 2022).
Installation
This code is based on PyTorch. Create and activate a conda environment by running the commands below:
conda env create --file environment.yml -n CxGrad
conda activate CxGrad
Datasets
We provide instructions to download four datasets: miniImageNet, tieredImageNet, CUB, and CIFAR-FS. Download the datasets you want to use and move them to datasets.
- miniImageNet: Download mini_imagenet_full_size.tar.bz2 from this link, provided in MAML++. Note that by downloading and using miniImageNet, you accept the terms and conditions found in imagenet_license.md.
- tieredImageNet: Download tiered_imagenet.tar from this link.
- CIFAR-FS: Download cifar100.zip from this link. The splits and the download link are provided by Bertinetto.
- CUB: Download CUB_200_2011.tgz from this link. The classes of each split are randomly chosen, so we provide the splits used in our experiments: CUB_split_train.txt, CUB_split_val.txt, and CUB_split_test.txt in datasets/preprocess. These splits were generated by a script written by Chen.
Then, run the command below to preprocess the datasets you downloaded.
python preprocess/preprocess.py --datasets DATASET1 DATASET2 ...
After preprocessing, the directory structure should look like this:
CxGrad
├── datasets
| ├── miniImageNet
| | ├── train
| | ├── val
| | └── test
| ├── tieredImageNet
| | ├── train
| | ├── val
| | └── test
| ├── CIFAR-FS
| | ├── train
| | ├── val
| | └── test
| └── CUB
| ├── train
| ├── val
| └── test
├── utils
├── README.md
└── ...
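As a quick sanity check before training, the layout above can be verified with a few lines of Python. This is only a sketch: the helper name missing_splits is hypothetical and not part of this repository, and it assumes the script is run from the CxGrad root so that datasets resolves to the folder shown in the tree.

```python
from pathlib import Path

# Split folder names taken from the directory tree above.
SPLITS = ("train", "val", "test")


def missing_splits(root, datasets):
    """Return the "dataset/split" folders that are absent under root.

    Hypothetical helper for illustration; not part of the CxGrad code.
    """
    root = Path(root)
    missing = []
    for name in datasets:
        for split in SPLITS:
            if not (root / name / split).is_dir():
                missing.append(f"{name}/{split}")
    return missing


if __name__ == "__main__":
    # Assumes the current working directory is the CxGrad repository root.
    wanted = ["miniImageNet", "tieredImageNet", "CIFAR-FS", "CUB"]
    for entry in missing_splits("datasets", wanted):
        print("missing:", entry)
```

Pass only the datasets you actually downloaded; an empty result means every listed dataset has its train, val, and test folders.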
Run experiments
- Change directory to experiment_scripts.
Train
- To train the model on N-way K-shot miniImageNet classification, run the command below, replacing N and K in the script path with the desired values
bash mini_imagenet_Nway_Kshot/CxGrad_4conv.sh GPU_ID
- For tieredImageNet, run
bash tiered_imagenet_Nway_Kshot/CxGrad_4conv.sh GPU_ID
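The script paths in the commands above follow a fixed naming pattern, with N and K substituted into the folder name. A minimal sketch of that pattern in Python is shown here; train_script and PREFIXES are hypothetical names used for illustration, and which N-way K-shot folders actually ship in experiment_scripts is not stated in this README, so check that the resulting path exists before running it.

```python
# Folder prefixes taken from the two training commands above.
PREFIXES = {
    "miniImageNet": "mini_imagenet",
    "tieredImageNet": "tiered_imagenet",
}


def train_script(dataset, n_way, k_shot):
    """Build the script path, relative to experiment_scripts, for one setting.

    Hypothetical helper; it only formats the name and does not check
    that the script exists.
    """
    return f"{PREFIXES[dataset]}_{n_way}way_{k_shot}shot/CxGrad_4conv.sh"


if __name__ == "__main__":
    # Print the bash command for 5-way 1-shot tieredImageNet on GPU 0.
    print("bash", train_script("tieredImageNet", 5, 1), 0)
```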
Test
- Example: test on CUB using the model trained on 5-way 5-shot miniImageNet
TEST=1 TEST_DATASET=CUB bash mini_imagenet_5way_5shot/CxGrad_4conv.sh GPU_ID
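To evaluate one trained model on several datasets, the same invocation can be driven from Python. The sketch below only mirrors the shell example above: TEST and TEST_DATASET are set as environment variables and GPU_ID is passed as the single argument. The function names are hypothetical, and how the script interprets these variables is not documented here.

```python
import os
import subprocess


def build_test_invocation(script, gpu_id, test_dataset, base_env=None):
    """Return (command, environment) mirroring the TEST=1 example above.

    Hypothetical helper; base_env defaults to a copy of os.environ.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["TEST"] = "1"
    env["TEST_DATASET"] = test_dataset
    return ["bash", script, str(gpu_id)], env


def run_evaluation(script, gpu_id, test_dataset, cwd="experiment_scripts"):
    """Run the script from experiment_scripts; raises if it exits non-zero."""
    cmd, env = build_test_invocation(script, gpu_id, test_dataset)
    return subprocess.run(cmd, env=env, cwd=cwd, check=True)
```

For example, run_evaluation("mini_imagenet_5way_5shot/CxGrad_4conv.sh", 0, "CUB") corresponds to the command shown above with GPU_ID set to 0, assuming it is called from the repository root.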
Citation
To be prepared
Acknowledgment
Thanks to the authors of MAML++ and ALFA, which our work is based on, for their great implementations.