PyTorch implementation of GLOM

Overview

GLOM

PyTorch implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns).

1. Overview

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" for MNIST Dataset.

2. Usage

2 - 1. PyTorch version

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)

Pass the return_all = True keyword argument on forward, and you will be returned all the column and level states per iteration, (including the initial state, number of iterations + 1). You can then use this to attach any losses to any level outputs at any time step.

It also gives you access to all the level data across iterations for clustering, from which one can inspect for the theorized islands in the paper.

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after iteration 6
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)

Denoising self-supervised learning for encouraging emergence, as described by Hinton

import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # get the top level embeddings after iteration 6
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()

You can pass in the state of the column and levels back into the model to continue where you left off (perhaps if you are processing consecutive frames of a slow video, as mentioned in the paper)

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iteratoins
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations

2 - 2. PyTorch-Lightning version

The pyglom also provides the GLOM model that is implemented with PyTorch-Lightning.

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import os
from pytorch_lightning.callbacks import ModelCheckpoint


from pyglom.glom import LightningGLOM


dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
]))
train, val = random_split(dataset, [55000, 5000])

glom = LightningGLOM(
    dim=256,         # dimension
    levels=6,        # number of levels
    image_size=256,  # image size
    patch_size=16,   # patch size
    img_channels=1
)

gpus = torch.cuda.device_count()
trainer = pl.Trainer(gpus=gpus, max_epochs=5)
trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=8, num_workers=2))

3. ToDo

  • contrastive / consistency regularization of top-ish levels

4. Citations

@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network}, 
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives

HashNeRF-pytorch Instant-NGP recently introduced a Multi-resolution Hash Encodin

Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, and DVC.

NN Template Generic template to bootstrap your PyTorch project. Click on Use this Template and avoid writing boilerplate code for: PyTorch Lightning,

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.

Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.
The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.

This is a curated list of tutorials, projects, libraries, videos, papers, books and anything related to the incredible PyTorch. Feel free to make a pu

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

A bunch of random PyTorch models using PyTorch's C++ frontend
A bunch of random PyTorch models using PyTorch's C++ frontend

PyTorch Deep Learning Models using the C++ frontend Gettting started Clone the repo 1. https://github.com/mrdvince/pytorchcpp 2. cd fashionmnist or

Releases(0.0.3)
Owner
Yeonwoo Sung
2020-09-21 ~ 2022-06-20 RoK (Korea) Air Force
Yeonwoo Sung
Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite and .pb from .tflite.

tflite2tensorflow Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite and .pb from .tflite. 1. Supported Layers No. TFLite Layer TF

Katsuya Hyodo 214 Dec 29, 2022
On-device speech-to-index engine powered by deep learning.

On-device speech-to-index engine powered by deep learning.

Picovoice 30 Nov 24, 2022
Source code of our BMVC 2021 paper: AniFormer: Data-driven 3D Animation with Transformer

AniFormer This is the PyTorch implementation of our BMVC 2021 paper AniFormer: Data-driven 3D Animation with Transformer. Haoyu Chen, Hao Tang, Nicu S

24 Nov 02, 2022
A python module for scientific analysis of 3D objects based on VTK and Numpy

A lightweight and powerful python module for scientific analysis and visualization of 3d objects.

Marco Musy 1.5k Jan 06, 2023
Categorizing comments on YouTube into different categories.

Youtube Comments Categorization This repo is for categorizing comments on a youtube video into different categories. negative (grievances, complaints,

Rhitik 5 Nov 26, 2022
[TOG 2021] PyTorch implementation for the paper: SofGAN: A Portrait Image Generator with Dynamic Styling.

This repository contains the official PyTorch implementation for the paper: SofGAN: A Portrait Image Generator with Dynamic Styling. We propose a SofGAN image generator to decouple the latent space o

Anpei Chen 694 Dec 23, 2022
MMFlow is an open source optical flow toolbox based on PyTorch

Documentation: https://mmflow.readthedocs.io/ Introduction English | įŽ€äŊ“中文 MMFlow is an open source optical flow toolbox based on PyTorch. It is a part

OpenMMLab 688 Jan 06, 2023
Official pytorch implementation of paper Dual-Level Collaborative Transformer for Image Captioning (AAAI 2021).

Dual-Level Collaborative Transformer for Image Captioning This repository contains the reference code for the paper Dual-Level Collaborative Transform

lyricpoem 160 Dec 11, 2022
MLP-Like Vision Permutator for Visual Recognition (PyTorch)

Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (arxiv) This is a Pytorch implementation of our paper. We present Vision

Qibin (Andrew) Hou 162 Nov 28, 2022
Banglore House Prediction Using Flask Server (Python)

Banglore House Prediction Using Flask Server (Python) 🌐 Links 🌐 📂 Repo In this repository, I've implemented a Machine Learning-based Bangalore Hous

Dhyan Shah 1 Jan 24, 2022
A free, multiplatform SDK for real-time facial motion capture using blendshapes, and rigid head pose in 3D space from any RGB camera, photo, or video.

mocap4face by Facemoji mocap4face by Facemoji is a free, multiplatform SDK for real-time facial motion capture based on Facial Action Coding System or

Facemoji 591 Dec 27, 2022
Prototypical Cross-Attention Networks for Multiple Object Tracking and Segmentation, NeurIPS 2021 Spotlight

PCAN for Multiple Object Tracking and Segmentation This is the offical implementation of paper PCAN for MOTS. We also present a trailer that consists

ETH VIS Group 328 Dec 29, 2022
Continual reinforcement learning baselines: experiment specifications, implementation of existing methods, and common metrics. Easily extensible to new methods.

Continual Reinforcement Learning This repository provides a simple way to run continual reinforcement learning experiments in PyTorch, including evalu

55 Dec 24, 2022
An off-line judger supporting distributed problem repositories

Thaw 中文 | English Thaw is an off-line judger supporting distributed problem repositories. Everyone can use Thaw release problems with license on GitHu

countercurrent_time 2 Jan 09, 2022
Real time sign language recognition

The proposed work aims at converting american sign language gestures into English that can be understood by everyone in real time.

Mohit Kaushik 6 Jun 13, 2022
Low-code/No-code approach for deep learning inference on devices

EzEdgeAI A concept project that uses a low-code/no-code approach to implement deep learning inference on devices. It provides a componentized framewor

On-Device AI Co., Ltd. 7 Apr 05, 2022
[CVPR 21] Vectorization and Rasterization: Self-Supervised Learning for Sketch and Handwriting, IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), 2021.

Vectorization and Rasterization: Self-Supervised Learning for Sketch and Handwriting, CVPR 2021. Ayan Kumar Bhunia, Pinaki nath Chowdhury, Yongxin Yan

Ayan Kumar Bhunia 44 Dec 12, 2022
Code for Environment Inference for Invariant Learning (ICML 2020 UDL Workshop Paper)

Environment Inference for Invariant Learning This code accompanies the paper Environment Inference for Invariant Learning, which appears at ICML 2021.

Elliot Creager 40 Dec 09, 2022
Keras community contributions

keras-contrib : Keras community contributions Keras-contrib is deprecated. Use TensorFlow Addons. The future of Keras-contrib: We're migrating to tens

Keras 1.6k Dec 21, 2022
[NeurIPS-2021] Slow Learning and Fast Inference: Efficient Graph Similarity Computation via Knowledge Distillation

Efficient Graph Similarity Computation - (EGSC) This repo contains the source code and dataset for our paper: Slow Learning and Fast Inference: Effici

24 Dec 31, 2022