An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" in PyTorch.

GLOM

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" for the MNIST dataset. To understand this implementation, please watch Yannic Kilcher's GLOM video, then read this README.md, then read the code.

Running

Open the notebook in Jupyter Notebook to run it. The program expects an NVIDIA graphics card for GPU acceleration. If you run out of GPU memory, decrease the batch_size variable. If you want to look at the code on GitHub and the page fails to load, try reloading it several times.

Results

The best models, which have been posted under the best_models folder, reached an accuracy of about 91%.

Implementation details

Three Types of networks per layer of vectors

  1. Top-Down Network
  2. Bottom-up Network
  3. Same-Layer Attention Network

Intro to State

There is an initial state, and after every time step the outputs of all three types of network are added to it. The bottom layer of the state is the input layer, which holds the MNIST pixel data; nothing is added to it, so the pixel data is retained. The top layer of the state is the output layer, where the loss function is applied; it is trained to match the one-hot MNIST target vector.
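A minimal sketch of this update rule, assuming the state is a tensor of shape (batch, layers, height, width, dim). The shape, the helper name apply_update, and the sizes below are illustrative assumptions and are not taken from the notebook.

```python
import torch

def apply_update(state: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    """Add the summed network outputs to every layer except the bottom one.

    Assumed layout: (batch, layers, height, width, dim), with layer 0 holding
    the MNIST pixel data and the last layer acting as the output layer.
    """
    new_state = state + update
    # Restore the bottom layer so the pixel data is retained unchanged.
    new_state[:, 0] = state[:, 0]
    return new_state

state = torch.rand(2, 4, 28, 28, 16)   # hypothetical sizes
update = torch.rand_like(state)
new_state = apply_update(state, update)
```

Only layers 1 and above change here; layer 0 is the same before and after the step.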

Explanation of compute_all function

Each type of network sees a 3x3 grid of vectors surrounding the current input vector at the current layer. This lets information travel laterally across vectors faster, so more information can cross the image in fewer steps. The easy way to do this is to shift (or roll) every vector along the x and y axes and then concatenate the shifted copies on top of each other, so that every position in the state that used to hold one vector now holds that vector together with its neighboring vectors in the same layer. This also connects the edges of the image, so data can pass from one edge of the image to the other, reducing the maximum distance between any two pixels or vectors.
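The roll-and-concatenate trick can be sketched with torch.roll as follows. The function name gather_neighbors and the (batch, height, width, dim) layout are assumptions made for illustration; the notebook may organize its tensors differently.

```python
import torch

def gather_neighbors(layer: torch.Tensor) -> torch.Tensor:
    """Concatenate each vector with its 8 neighbors, wrapping at the edges.

    layer: (batch, height, width, dim) -> (batch, height, width, 9 * dim)
    """
    shifted = [
        torch.roll(layer, shifts=(dy, dx), dims=(1, 2))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
    ]
    return torch.cat(shifted, dim=-1)

# A 4x4 grid of 1-dimensional vectors numbered 0..15, for easy inspection.
layer = torch.arange(16.0).reshape(1, 4, 4, 1)
neighbors = gather_neighbors(layer)
```

Because torch.roll wraps around, the corner position (0, 0) also receives values from the opposite edges of the grid, which is the edge connection described above.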

For a more complex dataset, it's possible this could pose some issues, since two opposite edges of an image aren't generally continuous, but for MNIST this problem doesn't arise. These vectors are then fed to each type of model. For each pixel, a model receives all the neighboring state vectors for a given layer as input and outputs a single vector. There are three types of models per layer. In this example, every line drawn is a separate model that is reused for every pixel. After each model type has given an output, the three lists of vectors are added together.

This will give a single list of vectors that will be added to the corresponding list of vectors at the specific x,y coordinate from the original state.

Repeating this step for every list of vectors at each x,y coordinate in the original state yields the full new state.
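Putting the pieces together, one full step might look like the sketch below. It is not the notebook's compute_all function: the class name GlomStep, the use of single nn.Linear layers as the three networks, the state layout (batch, layers, height, width, dim), and the choice of which layer each network reads (the layer below for bottom-up, the layer above for top-down) are all assumptions based on the general GLOM description.

```python
import torch
import torch.nn as nn

def gather_neighbors(layer: torch.Tensor) -> torch.Tensor:
    # (batch, height, width, dim) -> (batch, height, width, 9 * dim), with wraparound
    shifts = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1)]
    return torch.cat([torch.roll(layer, s, dims=(1, 2)) for s in shifts], dim=-1)

class GlomStep(nn.Module):
    """One state update; each network is shared across all pixels of a layer."""

    def __init__(self, num_layers: int, dim: int):
        super().__init__()

        def make_nets() -> nn.ModuleList:
            # One network per updated layer (layers 1 .. num_layers - 1).
            return nn.ModuleList(
                [nn.Linear(9 * dim, dim) for _ in range(num_layers - 1)]
            )

        self.bottom_up = make_nets()
        self.top_down = make_nets()   # the entry for the top layer is unused
        self.same_layer = make_nets()

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        num_layers = state.shape[1]
        new_layers = [state[:, 0]]    # the pixel layer is never updated
        for l in range(1, num_layers):
            delta = self.bottom_up[l - 1](gather_neighbors(state[:, l - 1]))
            delta = delta + self.same_layer[l - 1](gather_neighbors(state[:, l]))
            if l + 1 < num_layers:
                delta = delta + self.top_down[l - 1](gather_neighbors(state[:, l + 1]))
            # The three outputs are summed and added to the original state.
            new_layers.append(state[:, l] + delta)
        return torch.stack(new_layers, dim=1)

step = GlomStep(num_layers=4, dim=8)
small = step(torch.rand(2, 4, 5, 5, 8))
large = step(torch.rand(2, 4, 9, 9, 8))
```

The same module handles the 5x5 and the 9x9 inputs because each network only ever sees a 3x3 neighborhood.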

Since each network only sees a 3x3 grid and not larger image patches, this technique can be used for images of any size and is easily parallelizable.

If I had more compute

My 2080Ti runs into memory errors when the batch size is above roughly 30, so here are my implementation ideas if I had more compute.

  1. Increase batch_size. This probably won't affect the training, but it would make testing the accuracy faster.
  2. Save more states throughout the steps taken and add them together. This would allow gradients to be passed back to the original state, similar to how ResNet can train very large models because gradients can flow backwards more easily. This has already been implemented to a smaller degree and showed massive accuracy improvements.
  3. Perform some kind of evolutionary parameter search by mutating the model parameters while also using backprop. This has been shown to improve the accuracy of image classifiers and other models, but it would take a ton of compute.
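One possible reading of idea 2, sketched under stated assumptions: snapshots of the state are saved every few steps and summed, so each snapshot gives gradients a shorter path back to the original state. The function name run_and_sum_states, the save_every parameter, and the toy step function are hypothetical; how the notebook actually saves and combines states is not shown here.

```python
import torch

def run_and_sum_states(state, step_fn, num_steps, save_every):
    """Run step_fn repeatedly and return the sum of periodically saved states."""
    saved = []
    for t in range(1, num_steps + 1):
        state = step_fn(state)
        if t % save_every == 0:
            saved.append(state)
    # Summing the snapshots gives each one a direct path to the loss.
    return torch.stack(saved).sum(dim=0)

# Toy step function standing in for the real state update.
initial = torch.zeros(2, 3, requires_grad=True)
total = run_and_sum_states(initial, lambda s: s + 1, num_steps=6, save_every=2)
total.sum().backward()
```

In this toy run three snapshots are saved (after steps 2, 4 and 6), and the gradient reaching the initial state is the sum of the contributions from all three.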

Yannic Kilcher's Attention

This has been pushed to GitHub because, during testing and hyperparameter tuning, a model better than the previous one was found. More testing needs to be done, and I'm working on the visual explanation for it now. Previous versions of this code don't have the attention seen in the current version and will have similar performance.

Other Ideas behind the paper implementation

This is basically a neural cellular automaton from the paper Growing Neural Cellular Automata, with some inspiration from the follow-up paper Self-classifying MNIST Digits, except that instead of a single list of numbers (one vector) per pixel, there are several vectors per pixel in each image. The Growing Neural Cellular Automata model was also very difficult to train because of its long gradient chains, so increasing the model's complexity for GLOM makes training even harder. The neural cellular automata papers are also the reason why the MSE loss function is used and random noise is added to the state during training.
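The two training ingredients mentioned above can be sketched as follows. The helper names add_state_noise and one_hot_mse, the noise scale, and the assumption that the output layer has already been reduced to a (batch, 10) tensor are illustrative; the notebook's actual readout and noise settings may differ.

```python
import torch
import torch.nn.functional as F

def add_state_noise(state: torch.Tensor, std: float = 0.02) -> torch.Tensor:
    """Return the state with Gaussian noise added (std is an assumed value)."""
    return state + std * torch.randn_like(state)

def one_hot_mse(output: torch.Tensor, labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """MSE between the output layer readout and the one-hot target vector."""
    target = F.one_hot(labels, num_classes).float()
    return F.mse_loss(output, target)

labels = torch.tensor([3, 7])
noisy = add_state_noise(torch.zeros(2, 4, 28, 28, 16))
loss_all_zero = one_hot_mse(torch.zeros(2, 10), labels)
loss_perfect = one_hot_mse(F.one_hot(labels, 10).float(), labels)
```

An all-zero output misses exactly one of the ten target entries per sample, while an output equal to the one-hot target gives a loss of zero.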

To do

  1. Generate the explanation for Yannic Kilcher's version of attention that is implemented here.
  2. See if part-whole hierarchies are being found.
  3. Keep testing hyperparameters to push accuracy higher.
  4. Test different state initializations.
  5. Train on harder datasets.

If you find any issues, please feel free to contact me.

Owner
Just a random coder