Code and project page for ICCV 2021 paper "DisUnknown: Distilling Unknown Factors for Disentanglement Learning"


DisUnknown: Distilling Unknown Factors for Disentanglement Learning

See the introduction on our project page.

Requirements

  • PyTorch >= 1.8.0
  • PyYAML, for loading configuration files
  • Optional: h5py, for using the 3D Shapes dataset
  • Optional: Matplotlib, for plotting sample distributions in code space
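A quick way to check the requirements above before training is sketched below. The helpers version_at_least and report_dependencies are hypothetical and not part of the repository. The sketch assumes PyYAML is importable as yaml, and it compares only the leading numeric major.minor.patch part of torch.__version__.

```python
import importlib.util
import re


def version_at_least(version, minimum):
    """Compare a version string such as '1.10.2+cu113' with a (major, minor, patch) tuple."""
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if match is None:
        raise ValueError(f'unrecognized version string: {version!r}')
    # A missing patch number is treated as 0, so '2.0' becomes (2, 0, 0).
    parts = tuple(int(g) if g is not None else 0 for g in match.groups())
    return parts >= tuple(minimum)


def report_dependencies():
    """Return {requirement: satisfied} for the required and optional packages."""
    status = {}
    try:
        import torch
        status['torch>=1.8.0'] = version_at_least(torch.__version__, (1, 8, 0))
    except ImportError:
        status['torch>=1.8.0'] = False
    # PyYAML is required. h5py (3D Shapes) and Matplotlib (plotting) are optional.
    for name in ('yaml', 'h5py', 'matplotlib'):
        status[name] = importlib.util.find_spec(name) is not None
    return status


if __name__ == '__main__':
    for requirement, ok in report_dependencies().items():
        print(f'{requirement}: {"ok" if ok else "missing"}')
```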

Preparing Datasets

Dataset classes and default configurations are provided for the following datasets. See below for how to add new datasets, or open an issue and the author might consider adding it. Some datasets need to be prepared before use:

$ python prepare_data <dataset_name> --data_path </path/to/dataset>

If the dataset does not have a standard training/test split, it will be split randomly. Use the --test_portion <portion> option to set the portion of test samples. Some datasets have additional options.
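When no standard split exists, a random split controlled by --test_portion could look roughly like the sketch below. The helper random_split_indices and its seed parameter are hypothetical. The repository's own splitting code may shuffle, round, and store the split differently.

```python
import random


def random_split_indices(num_samples, test_portion, seed=0):
    """Hypothetical helper: randomly split range(num_samples) into (train, test) index lists."""
    rng = random.Random(seed)  # fixed seed, so the split is reproducible
    indices = list(range(num_samples))
    rng.shuffle(indices)
    num_test = int(round(num_samples * test_portion))
    # Sorting keeps on-disk order when the indices are later used to slice the data.
    return sorted(indices[num_test:]), sorted(indices[:num_test])


train_idx, test_idx = random_split_indices(100, 0.2)
print(len(train_idx), len(test_idx))
```

Fixing the seed matters here: preparing the data twice should not silently move samples between the training and test sets.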

  • MNIST, Fashion-MNIST, QMNIST and SVHN
    • Dataset names are mnist, fashion_mnist, qmnist and svhn.
    • data_path should be the same as the root path used by the corresponding built-in torchvision dataset classes.
    • For QMNIST we use the full NIST digit dataset (what = 'nist'), which needs to be split.
    • For SVHN, set include_extra: true in dataset_args in the configuration file (this is the default) to include the extra training images in the training set.
  • 3D Chairs
    • Dataset name is chairs.
    • data_path should be the folder containing the rendered_chairs folder.
    • Needs to be split.
    • You may use --compress to down-sample all images and save them as a NumPy array of PNG-encoded bytes. Use --downsample_size <size> to set the image size, which defaults to 128. Note that this does not dictate the training-time image size, which is configured separately. Compressing the images speeds up training only slightly if a multi-processing data loader is used, but makes plotting significantly faster.
    • Unrelated to this work, but the author wants to note that this dataset curiously contains 31 azimuth angles times two altitudes, for a total of 62 images per chair, with image ID 031 skipped. Apparently 32 azimuth angles were intended, but the angles were generated using numpy.linspace(0, 360, 32), ignoring the fact that 0 and 360 are the same angle, and the duplicated images 031 and 063 were removed after the mistake was noticed. Beware of off-by-one errors in linspace, especially if it is also circular!
  • 3D shapes
    • Dataset name is 3dshapes.
    • data_path should be the folder containing 3dshapes.h5.
    • Needs to be split.
    • You may use --compress to extract all images and save them as a NumPy array of PNG-encoded bytes. This is mainly for saving space: the original dataset, when extracted from HDF5, takes 5.9GB of memory, while the re-compressed version takes 2.2GB. Extraction and compression take about an hour.
  • dSprites
    • Dataset name is dsprites.
    • data_path should be the folder containing the .npz file.
    • Needs to be split.
    • This dataset is problematic. I found that orientation 0 and orientation 39 are the same, presumably because, as with 3D Chairs, something like linspace(0, 360, 40) was used to generate the angles. So yes, I'm telling you again: beware of off-by-one errors in linspace, especially if it is also circular! In my dataset class I discard orientation 39, so there are only 39 distinct orientations and 3 * 6 * 39 * 32 * 32 = 718848 images.
    • The bigger problem is that each of the three shapes (square, ellipse, heart) has a different symmetry. For hearts, each image uniquely determines an orientation angle; for ellipses, each image has two possible orientation angles; and for squares, each image has four. Because 2 and 4 are not divisors of 39, the dataset still manages (apart from orientations 0 and 39 being the same) to give different orientations different images, which makes me wonder whether the off-by-one error was intentional. The labeling is nonetheless conceptually wrong: if you take the orientation angles of ellipses modulo 180 or those of squares modulo 90, the orientation class IDs are no longer in increasing order of angle. Instead, the ellipse angles go around twice in this range and the square angles go around four times. To solve this problem, I included the option relabel_orientation: true in dataset_args in the configuration file (this is the default), which re-labels the orientations of ellipses and squares in the correct order. Specifically, for ellipses orientation t is re-labeled as (t * 2) % 39, and for squares as (t * 4) % 39. This still causes ellipses to rotate twice as slowly and squares four times as slowly as hearts when the orientation label increases, which is not ideal. When shapes with different symmetries are mixed there is simply no easy solution, so do not expect good results on this dataset if the unknown factor contains the orientation.
    • --compress does the same thing as in 3D Shapes.
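The two dataset notes above, the circular linspace pitfall and the dSprites orientation re-labeling, can be checked numerically. The NumPy sketch below is an illustration under stated assumptions. relabel_orientation and FOLD are hypothetical names that restate the formulas from the text, and they are not the repository's dataset class. The sketch assumes orientation t has angle t * 360 / 39 once orientation 39 is discarded.

```python
import numpy as np

# The circular-linspace pitfall: with the default endpoint=True the last sample
# is 360, which is the same direction as the first sample, 0.
angles_bad = np.linspace(0, 360, 40)
# endpoint=False gives 40 distinct directions, 9 degrees apart.
angles_good = np.linspace(0, 360, 40, endpoint=False)

# dSprites after discarding orientation 39: shapes * scales * orientations * posX * posY
NUM_IMAGES = 3 * 6 * 39 * 32 * 32

# Number of orientation angles that produce the same image, per shape.
FOLD = {'heart': 1, 'ellipse': 2, 'square': 4}


def relabel_orientation(t, shape, n=39):
    """Re-label orientation class t in [0, n) as (t * fold) % n, as described above."""
    return (t * FOLD[shape]) % n


print(NUM_IMAGES)
```

Because gcd(2, 39) = gcd(4, 39) = 1, the re-labeling is a permutation of 0..38. Sorting orientations by angle modulo the shape's symmetry period (180 for ellipses, 90 for squares) yields exactly the new labels in increasing order.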

Training

To train, use

$ python train --config_file </path/to/config/file> --save_path </path/to/save/folder>

The configuration file is in YAML; see the commented example for explanations. If config_file is omitted, save_path must already exist and contain config.yaml. Otherwise save_path will be created if it does not exist, and config_file will be copied into it. If save_path already contains a previous training run that was halted, training resumes from the latest checkpoint by default. Use --start_from <stage_name> [<iteration>] to choose another restarting point, for example --start_from stage1 to restart from scratch. Specifying --data_path or --device overrides the corresponding settings in the configuration file.
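The option rules above can be restated as a small argparse sketch. This is a hypothetical re-creation for illustration and not the repository's train script. The function names, the nargs='+' treatment of --start_from, and the (stage, iteration) return shape are all assumptions. Only the flag names come from the text.

```python
import argparse
import os


def build_train_parser():
    # Hypothetical parser mirroring the documented options.
    parser = argparse.ArgumentParser(description='DisUnknown training (sketch)')
    parser.add_argument('--config_file', default=None)
    parser.add_argument('--save_path', required=True)
    # <stage_name> [<iteration>]: one or two values
    parser.add_argument('--start_from', nargs='+', default=None)
    parser.add_argument('--data_path', default=None)  # would override the config file
    parser.add_argument('--device', default=None)     # would override the config file
    return parser


def resolve_run(args):
    """Return (config_path, copy_config, start_point) following the rules above.

    start_point is None to resume from the latest checkpoint, or
    (stage_name, iteration_or_None) when --start_from is given.
    """
    if args.config_file is None:
        config_path = os.path.join(args.save_path, 'config.yaml')
        if not os.path.isfile(config_path):
            raise FileNotFoundError(f'{config_path} not found and no --config_file given')
        copy_config = False
    else:
        config_path = args.config_file
        copy_config = True  # the config file would be copied into save_path
    start_point = None
    if args.start_from is not None:
        if len(args.start_from) > 2:
            raise ValueError('--start_from takes <stage_name> [<iteration>]')
        stage = args.start_from[0]
        iteration = int(args.start_from[1]) if len(args.start_from) == 2 else None
        start_point = (stage, iteration)
    return config_path, copy_config, start_point
```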

Although our goal is to deal with cases where some factors are labeled and some are unknown, it feels wrong not to extend to the cases where all factors are labeled or all factors are unknown. We do allow these, but some parts of our method become unnecessary and are discarded accordingly. In particular, if all factors are unknown, we just train a VAE in stage I and then a GAN with the same code space in stage II, so you can use this code to train just a GAN. We don't implement the myriad of GAN tricks, though.

Meaning of Visualization Images

During training, images generated for visualization are saved in the subfolder samples. test_images.jpg contains images from the test set in even-numbered columns (counting from zero), with odd-numbered columns left empty. The generated images contain the corresponding reconstructed images in even-numbered columns, while each image in an odd-numbered column is generated by combining the unknown code from the image on its left with the labeled code from the image on its right (wrapping to the next row).
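The grid described above can be indexed with the sketch below, which is useful when reading the sample images programmatically. It assumes the test images fill the even columns in row-major order, so the image "on the right" of an odd cell is simply the next test image. Wrapping the very last cell back to image 0 is a guess. The function name visualization_layout is hypothetical.

```python
def visualization_layout(num_images, ncols):
    """Map each grid cell (row, col) to what it shows.

    Even column: ('reconstruction', i) for test image i.
    Odd column right after image i: ('swap', i, j), meaning the unknown code
    of image i combined with the labeled code of image j = i + 1
    (the last cell wrapping to image 0 is an assumption).
    """
    if ncols % 2:
        raise ValueError('ncols must be even')
    per_row = ncols // 2
    layout = {}
    for i in range(num_images):
        row, slot = divmod(i, per_row)
        layout[(row, 2 * slot)] = ('reconstruction', i)
        layout[(row, 2 * slot + 1)] = ('swap', i, (i + 1) % num_images)
    return layout


for cell, content in sorted(visualization_layout(4, 4).items()):
    print(cell, content)
```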

Example test images:

Test images

Example generated images:


Adding a New Dataset

The dataset class's __init__() should accept four positional arguments, root, part, labeled_factors and transform, in that order, plus any additional keyword arguments expected from dataset_args in the configuration file. root is the path to the dataset folder, and transform is the image transform, as usual. part can be train, test or plot, specifying which subset of the dataset to load. The plotting set is generally the same as the test set, but part = 'plot' is passed in so that a smaller plotting set can be used if the test set is too large.

labeled_factors is a list of factor names. __getitem__() should return a tuple (image, labels) where image is the image and labels is a one-dimensional PyTorch tensor of type torch.int64, containing the labels for that image in the order listed in labeled_factors. labels should always be a one-dimensional tensor even if there is only one labeled factor, not a Python int or a zero-dimensional tensor. If labeled_factors is empty then __getitem__() should return image only.

In addition, metadata about the factors should be available in the following properties: nclass should be a list of ints containing the number of classes of each factor, and class_freq should be a list of PyTorch tensors, each being one-dimensional, containing the distribution of classes of each factor in (the current split of) the dataset.
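The interface above can be put together in a minimal dataset class, sketched below. Everything in it is hypothetical: the class name ToyFactorDataset, the two synthetic factors color and size, and the test_every keyword, which stands in for an extra argument passed through dataset_args. The sketch also reads nclass and class_freq as describing the factors in labeled_factors, in that order. That is an interpretation of the text, not something it states outright.

```python
import torch
from torch.utils.data import Dataset


class ToyFactorDataset(Dataset):
    """Hypothetical dataset with two fully labeled factors: 'color' (3) and 'size' (4)."""

    FACTORS = {'color': 3, 'size': 4}

    def __init__(self, root, part, labeled_factors, transform, test_every=4):
        # root is ignored because images are synthesized on the fly.
        combos = [(c, s) for c in range(3) for s in range(4)]
        if part == 'train':
            keep = [i for i in range(len(combos)) if i % test_every != 0]
        elif part in ('test', 'plot'):  # plot reuses the test split here
            keep = [i for i in range(len(combos)) if i % test_every == 0]
        else:
            raise ValueError(f'unknown part: {part!r}')
        names = list(self.FACTORS)
        columns = torch.tensor([names.index(f) for f in labeled_factors], dtype=torch.int64)
        self.all_labels = torch.tensor([combos[i] for i in keep], dtype=torch.int64)
        # Shape (N, len(labeled_factors)); also works when labeled_factors is empty.
        self.labels = self.all_labels.index_select(1, columns)
        self.labeled_factors = list(labeled_factors)
        self.transform = transform
        self.nclass = [self.FACTORS[f] for f in self.labeled_factors]
        # One 1-D tensor per labeled factor: the class distribution in this split.
        self.class_freq = [
            torch.bincount(self.labels[:, j], minlength=n).double() / len(keep)
            for j, n in enumerate(self.nclass)
        ]

    def __len__(self):
        return self.all_labels.shape[0]

    def __getitem__(self, index):
        color, size = self.all_labels[index].tolist()
        image = torch.zeros(1, 8, 8)
        side = size + 2
        image[:, :side, :side] = (color + 1) / 3.0  # brightness encodes color, square side encodes size
        if self.transform is not None:
            image = self.transform(image)
        if not self.labeled_factors:
            return image
        # Always a 1-D int64 tensor, even when only one factor is labeled.
        return image, self.labels[index].clone()
```

Storing all labels as one int64 matrix and selecting columns once in __init__ keeps __getitem__ cheap, and it makes the "always one-dimensional" requirement fall out of row indexing.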

If any preparation is required, implement a static method prepare_data(args), where args is the return value of argparse.ArgumentParser.parse_args() and contains the properties data_path and test_portion by default. If additional command-line arguments are needed, implement a static method add_prepare_args(parser), in which parser.add_argument() can be called.
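A driver could call the two static hooks above as sketched below. Both run_prepare and MyDataset are hypothetical and are not the repository's prepare_data script. The default of 0.1 for --test_portion is an assumption. The point of the sketch is the order: the dataset adds its own arguments before parsing, then receives the parsed namespace.

```python
import argparse


class MyDataset:
    """Hypothetical dataset class; only the two preparation hooks are shown."""

    @staticmethod
    def add_prepare_args(parser):
        # Dataset-specific option, added before parsing.
        parser.add_argument('--compress', action='store_true')

    @staticmethod
    def prepare_data(args):
        # A real implementation would read args.data_path and write split files.
        return {'data_path': args.data_path, 'test_portion': args.test_portion,
                'compress': args.compress}


def run_prepare(dataset_cls, argv):
    """Build the parser, let the dataset extend it, then call prepare_data if defined."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', required=True)
    parser.add_argument('--test_portion', type=float, default=0.1)  # default is an assumption
    if hasattr(dataset_cls, 'add_prepare_args'):
        dataset_cls.add_prepare_args(parser)
    args = parser.parse_args(argv)
    if hasattr(dataset_cls, 'prepare_data'):
        return dataset_cls.prepare_data(args)
    return None  # nothing to prepare
```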

Finally, add it to the dictionary of recognized datasets in data/

A default configuration should also be created as default_config/datasets/<dataset_name>.yaml. At a minimum it should contain image_size, image_channels and factors. factors has the same syntax as labeled_factors, as explained in the example training configuration, and should contain a complete list of all factors. In particular, if the dataset does not include a complete set of labels, there should be a factor called unknown, which becomes the default unknown factor if labeled_factors is not set in the training configuration.

Any additional settings in the default configuration will override global defaults in default_config/default_config.yaml.
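The override rule above, together with the minimum required keys, can be illustrated on plain dictionaries as already loaded from YAML. This is a sketch under assumptions: merge_config and missing_dataset_keys are hypothetical names, and merging nested sections recursively (rather than replacing them wholesale) is a guess about how the repository combines the two files.

```python
REQUIRED_DATASET_KEYS = ('image_size', 'image_channels', 'factors')


def missing_dataset_keys(dataset_config):
    """List the minimum keys named in the text that a dataset default config lacks."""
    return [key for key in REQUIRED_DATASET_KEYS if key not in dataset_config]


def merge_config(global_defaults, dataset_defaults):
    """Return a new dict where dataset-level settings override global defaults.

    Nested dicts are merged recursively (an assumption); the inputs are not modified.
    """
    merged = dict(global_defaults)
    for key, value in dataset_defaults.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged
```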

Citing This Work (BibTeX)

  title={DisUnknown: Distilling Unknown Factors for Disentanglement Learning},
  author={Xiang, Sitao and Gu, Yuming and Xiang, Pengda and Chai, Menglei and Li, Hao and Zhao, Yajie and He, Mingming},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
Sitao Xiang
Computer Graphics PhD student at University of Southern California. Twitter: StormRaiser123