Canonical Capsules: Unsupervised Capsules in Canonical Pose (NeurIPS 2021)
Introduction
This is the official repository for the PyTorch implementation of "Canonical Capsules: Unsupervised Capsules in Canonical Pose" by Weiwei Sun*, Andrea Tagliasacchi*, Boyang Deng, Sara Sabour, Soroosh Yazdani, Geoffrey Hinton, Kwang Moo Yi.
Download links
- Project Website
- PDF (arXiv)
- PDF (GitHub copy)
Citation
⚠️ If you use this source code or data in your research (in any shape or format), we require you to cite our paper as:
@conference{sun2020canonical,
title={Canonical Capsules: Unsupervised Capsules in Canonical Pose},
author={Weiwei Sun and Andrea Tagliasacchi and Boyang Deng and
Sara Sabour and Soroosh Yazdani and Geoffrey Hinton and
Kwang Moo Yi},
booktitle={Neural Information Processing Systems},
year={2021}
}
Requirements
Please install dependencies with the provided environment.yml:
conda env create -f environment.yml
Datasets
- We use the ShapeNet dataset as in AtlasNetV2: download the data from AtlasNetV2's official repo and convert it into h5 files with the provided script (data_utils/ShapeNetLoader.py).
- For faster experimentation, please use our 2D planes dataset, which we generated from ShapeNet (if you use this dataset, please cite both our paper and ShapeNet).
Training/testing (2D)
To train the model on 2D planes (training takes only 50 epochs at approximately 2.5 minutes per epoch on an NVIDIA GTX 1080 Ti, i.e., about two hours in total):
./main.py --log_dir=plane_dim2 --indim=2 --scheduler=5
To visualize the decomposition and reconstruction:
./main.py --save_dir=gifs_plane2d --indim=2 --scheduler=5 --mode=vis --pt_file=logs/plane_dim2/checkpoint.pth
Training/testing (3D)
To train the model on the 3D dataset:
./main.py --log_dir=plane_dim3 --indim=3 --cat_id=-1
We test the model with:
./main.py --log_dir=plane_dim3 --indim=3 --cat_id=-1 --mode=test
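The commands above follow a simple pattern: training commands omit --mode, visualization and testing add --mode=vis or --mode=test, the test run reuses the training options, and the 2D visualization reads its checkpoint from logs/<log_dir>/checkpoint.pth. A small Python sketch can assemble these argument lists. It uses only flags that appear in this README; the helper names (main_py_command, checkpoint_path) are hypothetical, and the logs/<log_dir>/checkpoint.pth layout is assumed from the 2D example rather than confirmed for every run.

```python
import shlex
from typing import List


def main_py_command(**options) -> List[str]:
    """Build an argv list for ./main.py; keyword order is preserved (Python 3.7+)."""
    return ["./main.py"] + [f"--{name}={value}" for name, value in options.items()]


def checkpoint_path(log_dir: str) -> str:
    """Hypothetical helper: assumes the logs/<log_dir>/checkpoint.pth layout
    used by the 2D visualization command also holds for other runs."""
    return f"logs/{log_dir}/checkpoint.pth"


# 2D: train, then visualize from the checkpoint written under logs/plane_dim2/.
train_2d = main_py_command(log_dir="plane_dim2", indim=2, scheduler=5)
vis_2d = main_py_command(save_dir="gifs_plane2d", indim=2, scheduler=5,
                         mode="vis", pt_file=checkpoint_path("plane_dim2"))

# 3D: the test run reuses the training options and only adds --mode=test.
train_3d = main_py_command(log_dir="plane_dim3", indim=3, cat_id=-1)
test_3d = main_py_command(log_dir="plane_dim3", indim=3, cat_id=-1, mode="test")

for argv in (train_2d, vis_2d, train_3d, test_3d):
    # Print a shell-ready command line; to launch it from the repository root,
    # pass argv to subprocess.run(argv, check=True) instead.
    print(shlex.join(argv))
```

Keeping the options in one place makes it harder for the test or visualization run to drift from the settings used for training.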
The cat_id option selects the category whose h5 files are loaded, according to the following look-up table:
| id | category |
|---|---|
| -1 | all |
| 0 | bench |
| 1 | cabinet |
| 2 | car |
| 3 | cellphone |
| 4 | chair |
| 5 | couch |
| 6 | firearm |
| 7 | lamp |
| 8 | monitor |
| 9 | plane |
| 10 | speaker |
| 11 | table |
| 12 | watercraft |
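For scripts that need to refer to categories by name, the table above can be mirrored as a plain dictionary. This is a hypothetical helper (CAT_ID_TO_CATEGORY and categories_for are illustrative names), not the repository's own loader, which may store the mapping differently; it only encodes the ids listed above, with -1 expanding to all 13 categories.

```python
from typing import List

# Hypothetical mirror of the cat_id look-up table; not taken from the repository code.
CAT_ID_TO_CATEGORY = {
    0: "bench",
    1: "cabinet",
    2: "car",
    3: "cellphone",
    4: "chair",
    5: "couch",
    6: "firearm",
    7: "lamp",
    8: "monitor",
    9: "plane",
    10: "speaker",
    11: "table",
    12: "watercraft",
}


def categories_for(cat_id: int) -> List[str]:
    """Return the category names selected by a --cat_id value (-1 means all)."""
    if cat_id == -1:
        # Keep the table order so the result is deterministic.
        return [CAT_ID_TO_CATEGORY[i] for i in sorted(CAT_ID_TO_CATEGORY)]
    if cat_id not in CAT_ID_TO_CATEGORY:
        raise ValueError(f"unknown cat_id {cat_id}; expected -1 or 0..12")
    return [CAT_ID_TO_CATEGORY[cat_id]]


print(categories_for(9))        # ['plane']
print(len(categories_for(-1)))  # 13
```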
Pre-trained models (3D)
We release 3D pre-trained models for both a single category (airplanes) and multiple categories (all 13 classes).
Classification
To use our classification script:
python classification.py --data_dir=/path/to/saved/features --feature_type=caca --method_type=svm --use_kpts
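The command above only shows how classification.py is invoked; its internals are not described here. As a rough illustration of what fitting an SVM on saved per-shape features involves, the sketch below is a from-scratch linear SVM (L2-regularized hinge loss trained with Pegasos-style stochastic subgradient descent, one-vs-rest across classes). It is a stand-in, not the repository's implementation, which may well use a library SVM. The function names, the hyperparameters lam and epochs, and the toy 2D "features" are all hypothetical.

```python
import random
from typing import Dict, List, Sequence, Tuple


def dot(w: Sequence[float], x: Sequence[float]) -> float:
    return sum(wi * xi for wi, xi in zip(w, x))


def train_binary_svm(X: Sequence[Sequence[float]], y: Sequence[int],
                     lam: float = 0.1, epochs: int = 200, seed: int = 0) -> List[float]:
    """Linear SVM (L2-regularized hinge loss) fitted with Pegasos-style SGD.

    y must contain +1/-1. A constant 1.0 is appended to every vector so the
    last weight acts as a bias (regularized like the others, for simplicity).
    """
    rng = random.Random(seed)
    data = [list(x) + [1.0] for x in X]
    w = [0.0] * len(data[0])
    order = list(range(len(data)))
    t = 0
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            t += 1
            eta = 1.0 / (lam * t)  # Pegasos step size
            active = y[i] * dot(w, data[i]) < 1.0  # hinge term has a non-zero subgradient
            w = [(1.0 - eta * lam) * wi for wi in w]  # L2 shrinkage
            if active:
                w = [wi + eta * y[i] * xi for wi, xi in zip(w, data[i])]
    return w


def train_one_vs_rest(X, labels, **kwargs) -> Dict[int, List[float]]:
    """One binary SVM per class: that class is +1, every other class is -1."""
    return {c: train_binary_svm(X, [1 if lab == c else -1 for lab in labels], **kwargs)
            for c in sorted(set(labels))}


def predict(models: Dict[int, List[float]], x: Sequence[float]) -> int:
    """Pick the class whose SVM gives the highest score."""
    xb = list(x) + [1.0]
    return max(models, key=lambda c: dot(models[c], xb))


# Toy stand-in for saved features: three well-separated 2D clusters.
centers = {0: (4.0, 0.0), 1: (-2.0, 3.5), 2: (-2.0, -3.5)}
offsets = [(0.3, 0.3), (-0.3, 0.3), (0.3, -0.3), (-0.3, -0.3)]
X: List[Tuple[float, float]] = []
labels: List[int] = []
for c, (cx, cy) in centers.items():
    for dx, dy in offsets:
        X.append((cx + dx, cy + dy))
        labels.append(c)

models = train_one_vs_rest(X, labels)
acc = sum(predict(models, x) == lab for x, lab in zip(X, labels)) / len(X)
print(f"training accuracy: {acc:.2f}")
```

With real data, X would hold one fixed-length feature vector per shape read from --data_dir; that storage format is an assumption here, and the --feature_type and --use_kpts options are not modeled.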
