Official implementation of the paper "Topographic VAEs learn Equivariant Capsules"

Overview

Topographic Variational Autoencoder

Paper: https://arxiv.org/abs/2109.01394

Getting Started

Install requirements with Anaconda:

conda env create -f environment.yml

Activate the conda environment

conda activate tvae

Install the tvae package

Install the tvae package inside your conda environment so that you can run experiments with the tvae command. From the root of the project directory, run (using your environment's pip):

pip3 install -e .

If you need help locating your environment's pip, run which python. It should point to a directory such as .../anaconda3/envs/tvae/bin/, where your environment's pip is also located.

(Optional) Setup Weights & Biases:

This repository uses Weights & Biases for experiment tracking. It is disabled by default. If you would like to use this (highly recommended!) functionality, set 'wandb_on': True in the experiment config, and set your account's project and entity names in the tvae/utils/logging.py file.

For more information on creating a Weights & Biases account, see the Weights & Biases guide on account creation and the associated quickstart guide.

Running an experiment

To rerun the experiment from Figure 3, you can run:

  • tvae --name 'tvae_2d_mnist'

To rerun the experiments from Figure 4, you can run:

  • tvae --name 'tvae_Lpartial_mnist'
  • tvae --name 'tvae_Lpartial_dsprites'

To rerun the experiments from Table 1, you can run:

  • tvae --name 'tvae_Lhalf_mnist'
  • tvae --name 'tvae_Lshort_mnist'
  • tvae --name 'bubbles_mnist'
  • tvae --name 'tvae_L0_mnist'
  • tvae --name 'nontvae_mnist'

To rerun the experiments from Table 2, you can run:

  • tvae --name 'tvae_Lhalf_dsprites'
  • tvae --name 'tvae_Lpartial_dsprites'
  • tvae --name 'tvae_Lshort_dsprites'
  • tvae --name 'bubbles_dsprites'
  • tvae --name 'tvae_L0_dsprites'
  • tvae --name 'nontvae_dsprites'

To rerun the generalization experiment described in Section B.4 (resulting in Figures 1 and 6), you can run:

  • tvae --name 'tvae_Lpartial_mnist_generalization'

To rerun the experiments from Figures 22 and 23 (training on complex combined transformations), you can run:

  • tvae --name 'tvae_Lpartial_perspective_mnist'
  • tvae --name 'tvae_Lpartial_rotcolor_mnist'
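
To queue several of the runs above in one go, a small Python wrapper around the tvae command can help. This is a minimal sketch that is not part of the repository: the EXPERIMENTS mapping only collects the experiment names listed above, and build_commands and run_experiments are hypothetical helpers that call tvae --name <experiment> once per entry through subprocess. Pass dry_run=True to print the commands without running them.

```python
import subprocess
from typing import Dict, List

# Experiment names copied from the lists above, grouped by the figure/table they reproduce.
EXPERIMENTS: Dict[str, List[str]] = {
    "figure_3": ["tvae_2d_mnist"],
    "figure_4": ["tvae_Lpartial_mnist", "tvae_Lpartial_dsprites"],
    "table_1": [
        "tvae_Lhalf_mnist",
        "tvae_Lshort_mnist",
        "bubbles_mnist",
        "tvae_L0_mnist",
        "nontvae_mnist",
    ],
    "table_2": [
        "tvae_Lhalf_dsprites",
        "tvae_Lpartial_dsprites",
        "tvae_Lshort_dsprites",
        "bubbles_dsprites",
        "tvae_L0_dsprites",
        "nontvae_dsprites",
    ],
    "generalization": ["tvae_Lpartial_mnist_generalization"],
    "figures_22_23": ["tvae_Lpartial_perspective_mnist", "tvae_Lpartial_rotcolor_mnist"],
}


def build_commands(group: str) -> List[List[str]]:
    """Return one `tvae --name <experiment>` argument list per experiment in `group`."""
    return [["tvae", "--name", name] for name in EXPERIMENTS[group]]


def run_experiments(group: str, dry_run: bool = False) -> None:
    """Run the experiments of `group` one after another (hypothetical helper)."""
    for cmd in build_commands(group):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the loop if a run exits with a non-zero status.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_experiments("table_1", dry_run=True)
```

Running the experiments sequentially, rather than launching them all at once, avoids several runs competing for the same GPU; adjust this to your hardware.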

Basics of the framework

  • All models are built using the TVAE module (see tvae/containers/tvae.py), which requires a z-encoder, a u-encoder, a decoder, and a 'grouper'. The grouper module defines the topographic structure of the latent space through a model (equivalent to W in the paper) and a padder, which defines the boundary conditions.
  • All experiments can be found in tvae/experiments/. Each experiment file begins with the model specification, followed by the experiment config, where important values such as L (group_kernel) and K (n_off_diag) can be set.
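
To build some intuition for what the grouper computes, here is a minimal NumPy sketch for a single time step. It assumes our reading of the paper, where the topographic variable is formed as t = (z - mu) / sqrt(W u^2): squared u-activations are summed over a local neighbourhood (the role of W, here a window of ones) and a circular padder supplies wrap-around boundary conditions. The function names, the window of ones, and the eps term are illustrative assumptions, not the repository's grouper API; check tvae/containers/ for the exact form.

```python
import numpy as np


def circular_pad(x: np.ndarray, pad: int) -> np.ndarray:
    """Illustrative 'padder': wrap-around padding along the capsule axis (last axis)."""
    return np.pad(x, ((0, 0), (pad, pad)), mode="wrap")


def group_sum(u_sq: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Illustrative stand-in for W: sum u**2 over a window of `kernel_size` neighbours.

    u_sq has shape (batch, cap_dim); the output has the same shape.
    """
    assert kernel_size % 2 == 1, "an odd window keeps the output length equal to cap_dim"
    padded = circular_pad(u_sq, kernel_size // 2)
    window = np.ones(kernel_size)
    # A 'valid' convolution over each padded row returns exactly cap_dim values.
    return np.stack([np.convolve(row, window, mode="valid") for row in padded])


def topographic_t(z: np.ndarray, u: np.ndarray, mu: float = 0.0,
                  kernel_size: int = 3, eps: float = 1e-6) -> np.ndarray:
    """t = (z - mu) / sqrt(W u^2 + eps), with W implemented by group_sum."""
    return (z - mu) / np.sqrt(group_sum(u ** 2, kernel_size) + eps)
```

This 1D version only groups units along the capsule dimension. As we understand the paper, W also couples neighbouring time steps (the L set through group_kernel), which this sketch omits.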

Model Architecture Options

  • 'n_caps': int, Number of independent capsules
  • 'cap_dim': int, Size of each capsule
  • 'n_transforms': int, Length of the total transformation sequence (denoted S in the paper)
  • 'mu_init': int, Initialization value for the mu parameter
  • 'n_off_diag': int, determines the spatial extent of the grouping within a single timestep (denoted K in the paper); n_off_diag=1 gives K=3, while n_off_diag=0 gives K=1.
  • 'group_kernel': tuple of int, defines the size of the kernel used by the grouper, exact definition and relationship to W varies for each experiment.
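
The two examples given for n_off_diag follow the pattern K = 2 * n_off_diag + 1: each latent unit is grouped with n_off_diag neighbours on either side within its capsule. The sketch below builds such a within-capsule grouping matrix as a circulant band of ones and checks that every row has K non-zero entries, then stacks n_caps independent capsules into a block-diagonal matrix. The helper names are hypothetical, and the wrap-around at capsule edges and the concatenation of capsules into one latent vector of width n_caps * cap_dim are our assumptions, not the repository's exact construction.

```python
import numpy as np


def kernel_width(n_off_diag: int) -> int:
    """K in the paper: the unit itself plus n_off_diag neighbours on each side."""
    return 2 * n_off_diag + 1


def within_capsule_W(cap_dim: int, n_off_diag: int) -> np.ndarray:
    """Circulant band: W[i, j] = 1 if j is within n_off_diag of i (mod cap_dim)."""
    idx = np.arange(cap_dim)
    diff = np.abs(idx[:, None] - idx[None, :])
    # Circular distance between positions i and j inside one capsule.
    dist = np.minimum(diff, cap_dim - diff)
    return (dist <= n_off_diag).astype(float)


def block_diagonal_W(n_caps: int, cap_dim: int, n_off_diag: int) -> np.ndarray:
    """Independent capsules: one block per capsule, no grouping across capsules."""
    return np.kron(np.eye(n_caps), within_capsule_W(cap_dim, n_off_diag))
```

As long as cap_dim > 2 * n_off_diag, every row of within_capsule_W sums to exactly K; smaller capsules would make the band wrap onto itself.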

Training Options

  • 'wandb_on': bool, if True, use weights & biases logging
  • 'lr': float, learning rate
  • 'momentum': float, standard momentum used in SGD
  • 'max_epochs': int, total training epochs
  • 'eval_epochs': int, number of epochs between evaluations on the test set (for MNIST)
  • 'batch_size': int, number of samples per batch
  • 'n_is_samples': int, number of importance samples when computing the log-likelihood on MNIST.
  • 'max_transform_len': int, (for dSprites) controls the subset of the dataset
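
'n_is_samples' sets the number of samples in an importance-sampling estimate of the log-likelihood. As a generic illustration (the standard importance-weighted estimator, not code taken from this repository), log p(x) is approximated as logsumexp_k[log p(x, z_k) - log q(z_k | x)] - log n_is_samples, with samples z_k drawn from q(z | x). The hypothetical function below computes that quantity stably from per-sample log-densities.

```python
import numpy as np


def importance_sampled_log_likelihood(log_joint: np.ndarray, log_q: np.ndarray) -> np.ndarray:
    """Estimate log p(x) from n_is_samples samples z_k ~ q(z|x).

    log_joint, log_q: arrays of shape (n_is_samples, batch) holding
    log p(x, z_k) and log q(z_k | x) for each sample and data point.
    Returns an array of shape (batch,).
    """
    log_w = log_joint - log_q                  # log importance weights
    n = log_w.shape[0]
    m = log_w.max(axis=0, keepdims=True)       # subtract the max for numerical stability
    return m[0] + np.log(np.exp(log_w - m).sum(axis=0)) - np.log(n)
```

With a single sample this is a one-sample estimate of the ELBO. In expectation, the estimate is a lower bound on log p(x) that tightens as the number of samples grows, at a proportionally higher cost.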

Acknowledgements

The Robert Bosch GmbH is acknowledged for financial support.

Owner
T. Andy Keller
PhD Student at UvA