PyTorch implementation of GLOM

Overview

GLOM

PyTorch implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns).

1. Overview

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" for MNIST Dataset.

2. Usage

2 - 1. PyTorch version

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)

Pass the return_all = True keyword argument on forward, and you will be returned all the column and level states per iteration, (including the initial state, number of iterations + 1). You can then use this to attach any losses to any level outputs at any time step.

It also gives you access to all the level data across iterations for clustering, from which one can inspect for the theorized islands in the paper.

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after iteration 6
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)

Denoising self-supervised learning for encouraging emergence, as described by Hinton

import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # get the top level embeddings after iteration 6
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()

You can pass in the state of the column and levels back into the model to continue where you left off (perhaps if you are processing consecutive frames of a slow video, as mentioned in the paper)

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iteratoins
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations

2 - 2. PyTorch-Lightning version

The pyglom also provides the GLOM model that is implemented with PyTorch-Lightning.

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import os
from pytorch_lightning.callbacks import ModelCheckpoint


from pyglom.glom import LightningGLOM


dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
]))
train, val = random_split(dataset, [55000, 5000])

glom = LightningGLOM(
    dim=256,         # dimension
    levels=6,        # number of levels
    image_size=256,  # image size
    patch_size=16,   # patch size
    img_channels=1
)

gpus = torch.cuda.device_count()
trainer = pl.Trainer(gpus=gpus, max_epochs=5)
trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=8, num_workers=2))

3. ToDo

  • contrastive / consistency regularization of top-ish levels

4. Citations

@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network}, 
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives

HashNeRF-pytorch Instant-NGP recently introduced a Multi-resolution Hash Encodin

Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, and DVC.

NN Template Generic template to bootstrap your PyTorch project. Click on Use this Template and avoid writing boilerplate code for: PyTorch Lightning,

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.

Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.
The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.

This is a curated list of tutorials, projects, libraries, videos, papers, books and anything related to the incredible PyTorch. Feel free to make a pu

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

A bunch of random PyTorch models using PyTorch's C++ frontend
A bunch of random PyTorch models using PyTorch's C++ frontend

PyTorch Deep Learning Models using the C++ frontend Gettting started Clone the repo 1. https://github.com/mrdvince/pytorchcpp 2. cd fashionmnist or

Releases(0.0.3)
Owner
Yeonwoo Sung
2020-09-21 ~ 2022-06-20 RoK (Korea) Air Force
Yeonwoo Sung
Optimizing DR with hard negatives and achieving SOTA first-stage retrieval performance on TREC DL Track (SIGIR 2021 Full Paper).

Optimizing Dense Retrieval Model Training with Hard Negatives Jingtao Zhan, Jiaxin Mao, Yiqun Liu, Jiafeng Guo, Min Zhang, Shaoping Ma 🔥 News 2021-10

Jingtao Zhan 99 Dec 27, 2022
Pyramid R-CNN: Towards Better Performance and Adaptability for 3D Object Detection

Pyramid R-CNN: Towards Better Performance and Adaptability for 3D Object Detection

61 Jan 07, 2023
An Extendible (General) Continual Learning Framework based on Pytorch - official codebase of Dark Experience for General Continual Learning

Mammoth - An Extendible (General) Continual Learning Framework for Pytorch NEWS STAY TUNED: We are working on an update of this repository to include

AImageLab 277 Dec 28, 2022
catch-22: CAnonical Time-series CHaracteristics

catch22 - CAnonical Time-series CHaracteristics About catch22 is a collection of 22 time-series features coded in C that can be run from Python, R, Ma

Carl H Lubba 229 Oct 21, 2022
Credit fraud detection in Python using a Jupyter Notebook

Credit-Fraud-Detection - Credit fraud detection in Python using a Jupyter Notebook , using three classification models (Random Forest, Gaussian Naive Bayes, Logistic Regression) from the sklearn libr

Ali Akram 4 Dec 28, 2021
Train robotic agents to learn pick and place with deep learning for vision-based manipulation in PyBullet.

Ravens is a collection of simulated tasks in PyBullet for learning vision-based robotic manipulation, with emphasis on pick and place. It features a Gym-like API with 10 tabletop rearrangement tasks,

Google Research 367 Jan 09, 2023
HuSpaCy: industrial-strength Hungarian natural language processing

HuSpaCy: Industrial-strength Hungarian NLP HuSpaCy is a spaCy model and a library providing industrial-strength Hungarian language processing faciliti

HuSpaCy 120 Dec 14, 2022
Distributed Deep learning with Keras & Spark

Elephas: Distributed Deep Learning with Keras & Spark Elephas is an extension of Keras, which allows you to run distributed deep learning models at sc

Max Pumperla 1.6k Jan 05, 2023
Code repo for "FASA: Feature Augmentation and Sampling Adaptation for Long-Tailed Instance Segmentation" (ICCV 2021)

FASA: Feature Augmentation and Sampling Adaptation for Long-Tailed Instance Segmentation (ICCV 2021) This repository contains the implementation of th

Yuhang Zang 21 Dec 17, 2022
Shape Matching of Real 3D Object Data to Synthetic 3D CADs (3DV project @ ETHZ)

Real2CAD-3DV Shape Matching of Real 3D Object Data to Synthetic 3D CADs (3DV project @ ETHZ) Group Member: Yue Pan, Yuanwen Yue, Bingxin Ke, Yujie He

24 Jun 22, 2022
Colab notebook and additional materials for Python-driven analysis of redlining data in Philadelphia

RedliningExploration The Google Colaboratory file contained in this repository contains work inspired by a project on educational inequality in the Ph

Benjamin Warren 1 Jan 20, 2022
The author's officially unofficial PyTorch BigGAN implementation.

BigGAN-PyTorch The author's officially unofficial PyTorch BigGAN implementation. This repo contains code for 4-8 GPU training of BigGANs from Large Sc

Andy Brock 2.6k Jan 02, 2023
Alternatives to Deep Neural Networks for Function Approximations in Finance

Alternatives to Deep Neural Networks for Function Approximations in Finance Code companion repo Overview This is a repository of Python code to go wit

15 Dec 17, 2022
Data and analysis code for an MS on SK VOC genomes phenotyping/neutralisation assays

Description Summary of phylogenomic methods and analyses used in "Immunogenicity of convalescent and vaccinated sera against clinical isolates of ance

Finlay Maguire 1 Jan 06, 2022
Public implementation of "Learning from Suboptimal Demonstration via Self-Supervised Reward Regression" from CoRL'21

Self-Supervised Reward Regression (SSRR) Codebase for CoRL 2021 paper "Learning from Suboptimal Demonstration via Self-Supervised Reward Regression "

19 Dec 12, 2022
VR Viewport Pose Model for Quantifying and Exploiting Frame Correlations

This repository contains the introduction to the collected VRViewportPose dataset and the code for the IEEE INFOCOM 2022 paper: "VR Viewport Pose Model for Quantifying and Exploiting Frame Correlatio

0 Aug 10, 2022
g9.py - Torch interactive graphics

g9.py - Torch interactive graphics A Torch toy in the browser. Demo at https://srush.github.io/g9py/ This is a shameless copy of g9.js, written in Pyt

Sasha Rush 13 Nov 16, 2022
Faster Convex Lipschitz Regression

Faster Convex Lipschitz Regression This reepository provides a python implementation of our Faster Convex Lipschitz Regression algorithm with GPU and

Ali Siahkamari 0 Nov 19, 2021
You Only 👀 One Sequence

You Only 👀 One Sequence TL;DR: We study the transferability of the vanilla ViT pre-trained on mid-sized ImageNet-1k to the more challenging COCO obje

Hust Visual Learning Team 666 Jan 03, 2023
A high performance implementation of HDBSCAN clustering.

HDBSCAN HDBSCAN - Hierarchical Density-Based Spatial Clustering of Applications with Noise. Performs DBSCAN over varying epsilon values and integrates

2.3k Jan 02, 2023