PyTorch implementation of GLOM

Overview

GLOM

PyTorch implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns).

1. Overview

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" for MNIST Dataset.

2. Usage

2 - 1. PyTorch version

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)

Pass the return_all = True keyword argument on forward, and you will be returned all the column and level states per iteration, (including the initial state, number of iterations + 1). You can then use this to attach any losses to any level outputs at any time step.

It also gives you access to all the level data across iterations for clustering, from which one can inspect for the theorized islands in the paper.

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after iteration 6
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)

Denoising self-supervised learning for encouraging emergence, as described by Hinton

import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # get the top level embeddings after iteration 6
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()

You can pass in the state of the column and levels back into the model to continue where you left off (perhaps if you are processing consecutive frames of a slow video, as mentioned in the paper)

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iteratoins
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations

2 - 2. PyTorch-Lightning version

The pyglom also provides the GLOM model that is implemented with PyTorch-Lightning.

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import os
from pytorch_lightning.callbacks import ModelCheckpoint


from pyglom.glom import LightningGLOM


dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
]))
train, val = random_split(dataset, [55000, 5000])

glom = LightningGLOM(
    dim=256,         # dimension
    levels=6,        # number of levels
    image_size=256,  # image size
    patch_size=16,   # patch size
    img_channels=1
)

gpus = torch.cuda.device_count()
trainer = pl.Trainer(gpus=gpus, max_epochs=5)
trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=8, num_workers=2))

3. ToDo

  • contrastive / consistency regularization of top-ish levels

4. Citations

@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network}, 
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives

HashNeRF-pytorch Instant-NGP recently introduced a Multi-resolution Hash Encodin

Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, and DVC.

NN Template Generic template to bootstrap your PyTorch project. Click on Use this Template and avoid writing boilerplate code for: PyTorch Lightning,

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.

Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.
The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.

This is a curated list of tutorials, projects, libraries, videos, papers, books and anything related to the incredible PyTorch. Feel free to make a pu

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

A bunch of random PyTorch models using PyTorch's C++ frontend
A bunch of random PyTorch models using PyTorch's C++ frontend

PyTorch Deep Learning Models using the C++ frontend Gettting started Clone the repo 1. https://github.com/mrdvince/pytorchcpp 2. cd fashionmnist or

Releases(0.0.3)
Owner
Yeonwoo Sung
2020-09-21 ~ 2022-06-20 RoK (Korea) Air Force
Yeonwoo Sung
Ian Covert 130 Jan 01, 2023
TLoL (Python Module) - League of Legends Deep Learning AI (Research and Development)

TLoL-py - League of Legends Deep Learning Library TLoL-py is the Python component of the TLoL League of Legends deep learning library. It provides a s

7 Nov 29, 2022
CT-Net: Channel Tensorization Network for Video Classification

[ICLR2021] CT-Net: Channel Tensorization Network for Video Classification @inproceedings{ li2021ctnet, title={{\{}CT{\}}-Net: Channel Tensorization Ne

33 Nov 15, 2022
Nonuniform-to-Uniform Quantization: Towards Accurate Quantization via Generalized Straight-Through Estimation. In CVPR 2022.

Nonuniform-to-Uniform Quantization This repository contains the training code of N2UQ introduced in our CVPR 2022 paper: "Nonuniform-to-Uniform Quanti

Zechun Liu 60 Dec 28, 2022
JupyterNotebook - C/C++, Javascript, HTML, LaTex, Shell scripts in Jupyter Notebook Also run them on remote computer

JupyterNotebook Read, write and execute C, C++, Javascript, Shell scripts, HTML, LaTex in jupyter notebook, And also execute them on remote computer R

1 Jan 09, 2022
Deep Reinforcement Learning for Keras.

Deep Reinforcement Learning for Keras What is it? keras-rl implements some state-of-the art deep reinforcement learning algorithms in Python and seaml

Keras-RL 0 Dec 15, 2022
A simple Tensorflow based library for deep and/or denoising AutoEncoder.

libsdae - deep-Autoencoder & denoising autoencoder A simple Tensorflow based library for Deep autoencoder and denoising AE. Library follows sklearn st

Rajarshee Mitra 147 Nov 18, 2022
Repositorio oficial del curso IIC2233 Programación Avanzada 🚀✨

IIC2233 - Programación Avanzada Evaluación Las evaluaciones serán efectuadas por medio de actividades prácticas en clases y tareas. Se calculará la no

IIC2233 @ UC 47 Sep 06, 2022
Code and models for ICCV2021 paper "Robust Object Detection via Instance-Level Temporal Cycle Confusion".

Robust Object Detection via Instance-Level Temporal Cycle Confusion This repo contains the implementation of the ICCV 2021 paper, Robust Object Detect

Xin Wang 69 Oct 13, 2022
A Python parser that takes the content of a text file and then reads it into variables.

Text-File-Parser A Python parser that takes the content of a text file and then reads into variables. Input.text File 1. What is your ***? 1. 18 -

Kelvin 0 Jul 26, 2021
Facial Expression Detection In The Realtime

The human's facial expressions is very important to detect thier emotions and sentiment. It can be very efficient to use to make our computers make interviews. Furthermore, we have robots now can det

Adel El-Nabarawy 4 Mar 01, 2022
Code for Massive-scale Decoding for Text Generation using Lattices

Massive-scale Decoding for Text Generation using Lattices Jiacheng Xu, Greg Durrett TL;DR: a new search algorithm to construct lattices encoding many

Jiacheng Xu 37 Dec 18, 2022
Revisiting Temporal Alignment for Video Restoration

Revisiting Temporal Alignment for Video Restoration [arXiv] Kun Zhou, Wenbo Li, Liying Lu, Xiaoguang Han, Jiangbo Lu We provide our results at Google

52 Dec 25, 2022
A Gura parser implementation for Python

Gura Python parser This repository contains the implementation of a Gura (compliant with version 1.0.0) format parser in Python. Installation pip inst

Gura Config Lang 19 Jan 25, 2022
Pytorch implementation of paper: "NeurMiPs: Neural Mixture of Planar Experts for View Synthesis"

NeurMips: Neural Mixture of Planar Experts for View Synthesis This is the official repo for PyTorch implementation of paper "NeurMips: Neural Mixture

James Lin 101 Dec 13, 2022
The source code and data of the paper "Instance-wise Graph-based Framework for Multivariate Time Series Forecasting".

IGMTF The source code and data of the paper "Instance-wise Graph-based Framework for Multivariate Time Series Forecasting". Requirements The framework

Wentao Xu 24 Dec 05, 2022
Mind the Trade-off: Debiasing NLU Models without Degrading the In-distribution Performance

Models for natural language understanding (NLU) tasks often rely on the idiosyncratic biases of the dataset, which make them brittle against test cases outside the training distribution.

Ubiquitous Knowledge Processing Lab 22 Jan 02, 2023
House3D: A Rich and Realistic 3D Environment

House3D: A Rich and Realistic 3D Environment Yi Wu, Yuxin Wu, Georgia Gkioxari and Yuandong Tian House3D is a virtual 3D environment which consists of

Meta Research 1.1k Dec 14, 2022
This repository contains project created during the Data Challenge module at London School of Hygiene & Tropical Medicine

LSHTM_RCS This repository contains project created during the Data Challenge module at London School of Hygiene & Tropical Medicine (LSHTM) in collabo

Lukas Kopecky 3 Jan 30, 2022
This is a simple face recognition mini project that was completed by a team of 3 members in 1 week's time

PeekingDuckling 1. Description This is an implementation of facial identification algorithm to detect and identify the faces of the 3 team members Cla

Eric Kwok 2 Jan 25, 2022