LSHFM.detection
This is the PyTorch source code for Distilling Knowledge by Mimicking Features. This project contains the code for object detection with feature mimicking. For image classification, please visit LSHFM.classification.
Dependencies
- python
- pytorch 1.7.1
- torchvision 0.8.2
Prepare the dataset
Please prepare the COCO and VOC datasets yourself. Then edit the get_data_path function in src/dataset/coco_utils.py and src/dataset/voc_utils.py so that it points to where the datasets are stored on your machine.
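As a rough illustration of the kind of edit meant here, the sketch below shows one way a dataset-path helper could resolve a root directory. The signature, the environment variable name LSHFM_DATA_ROOT, and the default directory are assumptions made for this example; the actual get_data_path in the repository may take different arguments, so adapt the idea rather than copying the code.

```python
import os
from pathlib import Path


def get_data_path(env_var="LSHFM_DATA_ROOT", default="~/datasets"):
    """Hypothetical stand-in for get_data_path: return the dataset root as a string.

    The environment variable name and the default directory are assumptions
    for this sketch, not values used by the repository.
    """
    root = os.environ.get(env_var, default)
    # Expand "~" so the returned path can be handed directly to a dataset loader.
    return str(Path(root).expanduser())


if __name__ == "__main__":
    print(get_data_path())
```

Reading the location from an environment variable avoids hard-coding a machine-specific path in two files.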
Run
You can run the experiments by
PORT=4444 bash experiments/[script name].sh 0,1,2,3
The training set consists of VOC2007 trainval and VOC2012 trainval, while the test set is VOC2007 test.
We train all models for 24 epochs, with the learning rate decaying at the 18th and 22nd epochs.
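The schedule described above is a step schedule. The sketch below computes the learning rate per epoch in plain Python, mirroring the behaviour of PyTorch's MultiStepLR. The base learning rate of 0.02, the decay factor of 0.1, and 0-indexed epochs are assumptions for illustration; the values actually used are set in the training scripts.

```python
def lr_at_epoch(epoch, base_lr=0.02, milestones=(18, 22), gamma=0.1):
    """Step schedule: multiply base_lr by gamma once for every milestone reached.

    base_lr and gamma are assumed values; only the milestones (18, 22) and the
    24-epoch length come from the text above.
    """
    reached = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** reached


if __name__ == "__main__":
    for epoch in range(24):
        print(f"epoch {epoch:2d}: lr = {lr_at_epoch(epoch):.6f}")
```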
Faster R-CNN
Before you run the KD experiments, please make sure the teacher model weights have been saved in pretrained. You can first run the ResNet101 baseline and the VGG16 baseline to train the teacher models, then move the resulting checkpoints to pretrained and edit --teacher-ckpt in the training shell scripts. Alternatively, you can download voc0712_fasterrcnn_r101_83.6 and voc0712_fasterrcnn_vgg16fpn_79.0 directly and move them to pretrained.
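A quick check before launching a KD run can catch a missing teacher checkpoint early. The helper below is a hypothetical sketch and not part of the repository; it only assumes that checkpoints are regular files inside a directory (pretrained by default).

```python
from pathlib import Path


def missing_checkpoints(names, directory="pretrained"):
    """Return the entries of `names` that are not regular files in `directory`."""
    root = Path(directory)
    return [name for name in names if not (root / name).is_file()]


if __name__ == "__main__":
    # Names as listed above; adjust them if your downloaded files carry an extension.
    wanted = ["voc0712_fasterrcnn_r101_83.6", "voc0712_fasterrcnn_vgg16fpn_79.0"]
    missing = missing_checkpoints(wanted)
    if missing:
        print("Missing teacher checkpoints:", ", ".join(missing))
    else:
        print("All teacher checkpoints found.")
```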
- ResNet101 baseline: voc0712_fasterrcnn_r101_baseline.sh
- ResNet50 baseline: voc0712_fasterrcnn_r50_baseline.sh
- ResNet50@ResNet101 LSHL2: voc0712_fasterrcnn_r50_r101_lshl2.sh
- VGG16 baseline: voc0712_fasterrcnn_vgg16fpn_baseline.sh
- VGG11 baseline: voc0712_fasterrcnn_vgg11fpn_baseline.sh
- VGG11@VGG16 L2: voc0712_fasterrcnn_vgg11fpn_vgg16fpn_l2.sh
- VGG11@VGG16 LSH: voc0712_fasterrcnn_vgg11fpn_vgg16fpn_lsh.sh
- VGG11@VGG16 LSHL2: voc0712_fasterrcnn_vgg11fpn_vgg16fpn_lshl2.sh
Method | ResNet50@ResNet101 | VGG11@VGG16 |
---|---|---|
Teacher | 83.6 | 79.0 |
Student | 82.0 | 75.1 |
L2 | 83.0 | 76.8 |
LSH | 82.6 | 76.7 |
LSHL2 | 83.0 | 77.2 |
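To read the table as gains over the student baseline, the short sketch below subtracts the Student score from each KD result. The numbers are copied from the table above; the dictionary keys are shorthand borrowed from the script names and are used here only for illustration.

```python
# Faster R-CNN scores copied from the table above.
FASTER_RCNN = {
    "r50_r101": {"Teacher": 83.6, "Student": 82.0, "L2": 83.0, "LSH": 82.6, "LSHL2": 83.0},
    "vgg11fpn_vgg16fpn": {"Teacher": 79.0, "Student": 75.1, "L2": 76.8, "LSH": 76.7, "LSHL2": 77.2},
}


def gains_over_student(scores):
    """Return each KD method's improvement over the Student baseline, in points."""
    base = scores["Student"]
    return {
        method: round(value - base, 1)
        for method, value in scores.items()
        if method not in ("Teacher", "Student")
    }


if __name__ == "__main__":
    for pair, scores in FASTER_RCNN.items():
        print(pair, gains_over_student(scores))
```

With these numbers, LSHL2 matches or exceeds L2 for both teacher-student pairs, and the gain is larger for the VGG pair, where the gap between teacher and student is wider.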
RetinaNet
As with Faster R-CNN, please make sure the teacher models are in pretrained. You can download the teacher models voc0712_retinanet_r101_83.0.ckpt and voc0712_retinanet_vgg16fpn_76.6.ckpt.
- ResNet101 baseline: voc0712_retinanet_r101_baseline.sh
- ResNet50 baseline: voc0712_retinanet_r50_baseline.sh
- ResNet50@ResNet101 LSHL2: voc0712_retinanet_r50_r101_lshl2.sh
- VGG16 baseline: voc0712_retinanet_vgg16fpn_baseline.sh
- VGG11 baseline: voc0712_retinanet_vgg11fpn_baseline.sh
- VGG11@VGG16 L2: voc0712_retinanet_vgg11fpn_vgg16fpn_l2.sh
- VGG11@VGG16 LSHL2: voc0712_retinanet_vgg11fpn_vgg16fpn_lshl2.sh
Method | ResNet50@ResNet101 | VGG11@VGG16 |
---|---|---|
Teacher | 83.0 | 76.6 |
Student | 82.5 | 73.2 |
L2 | 82.6 | 74.8 |
LSHL2 | 83.0 | 75.2 |
We find that training RetinaNet with LSH KD easily produces a NaN loss.
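One common way to fail fast in this situation is to check the loss for non-finite values at every step. The guard below is a generic sketch and not part of this repository; check_loss is a hypothetical helper, and in a PyTorch training loop the scalar would typically come from loss.item().

```python
import math


def check_loss(loss_value, step):
    """Raise if the scalar loss is NaN or infinite; otherwise return it unchanged."""
    if not math.isfinite(loss_value):
        raise FloatingPointError(f"non-finite loss {loss_value!r} at step {step}")
    return loss_value


if __name__ == "__main__":
    # Simulated loss values; the third one triggers the guard.
    for step, value in enumerate([0.8, 0.5, float("nan")]):
        try:
            check_loss(value, step)
        except FloatingPointError as err:
            print("stopping:", err)
            break
```

Stopping at the first non-finite loss keeps the last good checkpoint usable and makes it easier to see which step diverged.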
Visualize
Visualize the ground-truth labels:
python src/visual.py --dataset voc07 --idx 1 --gt
Visualize the model predictions:
python src/visual.py --dataset voc07 --idx 2 --model fasterrcnn_resnet50_fpn --checkpoint results/voc0712/fasterrcnn_resnet50_fpn/2020-12-11_20\:14\:09/model_13.pth
Citing this repository
If you find this code useful in your research, please consider citing us:
@article{LSHFM,
title={Distilling knowledge by mimicking features},
author={Wang, Guo-Hua and Ge, Yifan and Wu, Jianxin},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2021},
}
Acknowledgement
This project is based on https://github.com/pytorch/vision/tree/master/references/detection. Because this project targets object detection only, the code for segmentation and keypoint detection has been removed.