E(n)-Equivariant Transformer (wip)

Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention.
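For reference, the EGNN layer from Satorras, Hoogeboom and Welling (2021), which this project builds on, updates node features h_i and coordinates x_i as below (transcribed from the paper's layer definition). Here phi_e, phi_x and phi_h are MLPs, a_ij are edge attributes and C is a normalizing constant. How En-transformer replaces these sums with attention is not described here; the source code is the authority on that.

```latex
\begin{aligned}
m_{ij} &= \phi_e\left(h_i^l,\, h_j^l,\, \lVert x_i^l - x_j^l \rVert^2,\, a_{ij}\right) \\
x_i^{l+1} &= x_i^l + C \sum_{j \neq i} \left(x_i^l - x_j^l\right) \phi_x\left(m_{ij}\right) \\
m_i &= \sum_{j \neq i} m_{ij} \\
h_i^{l+1} &= \phi_h\left(h_i^l,\, m_i\right)
\end{aligned}
```

Because coordinates enter phi_e only through squared distances, and the coordinate update is a weighted sum of difference vectors, rotating, reflecting or translating the input coordinates transforms the output coordinates in the same way while leaving the features h unchanged.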

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

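model = EnTransformer(
    dim = 512,         # node feature dimension
    depth = 4,         # number of layers
    dim_head = 64,     # dimension per attention head
    heads = 8,         # number of attention heads
    edge_dim = 4,      # size of the edge features (last dimension of `edges`)
    fourier_features = 2
)

feats = torch.randn(1, 16, 512)    # node features: (batch, nodes, dim)
coors = torch.randn(1, 16, 3)      # node coordinates: (batch, nodes, 3)
edges = torch.randn(1, 16, 16, 4)  # pairwise edge features: (batch, nodes, nodes, edge_dim)

feats, coors = model(feats, coors, edges)  # (1, 16, 512), (1, 16, 3)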
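One way to sanity-check the (feats, coors) outputs is to apply a random orthogonal transform plus a shift to the input coordinates and confirm that the features are unchanged while the coordinates move the same way. The helper below is a hypothetical sketch, demonstrated on a toy equivariant function rather than on EnTransformer itself, so it runs without the package installed.

```python
import torch


def random_orthogonal(dim: int = 3) -> torch.Tensor:
    # QR of a Gaussian matrix; fixing column signs makes Q uniformly distributed
    # over O(dim), so reflections are sampled as well as rotations
    q, r = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))
    return q * torch.sign(torch.diagonal(r)).unsqueeze(0)


def check_equivariance(model_fn, feats, coors, atol=1e-6):
    """Hypothetical helper: test E(n) equivariance of model_fn on one input.

    model_fn(feats, coors) -> (feats_out, coors_out); coors has shape (b, n, d).
    Features should be invariant; coordinates should rotate/reflect and shift
    exactly like the input.
    """
    d = coors.shape[-1]
    rot = random_orthogonal(d).to(coors.dtype)
    shift = torch.randn(1, 1, d, dtype=coors.dtype)

    feats_1, coors_1 = model_fn(feats, coors)
    feats_2, coors_2 = model_fn(feats, coors @ rot + shift)

    feats_ok = torch.allclose(feats_1, feats_2, atol=atol)
    coors_ok = torch.allclose(coors_1 @ rot + shift, coors_2, atol=atol)
    return feats_ok and coors_ok


def toy_model(feats, coors):
    # toy E(n)-equivariant map: scale features by distance to the centroid and
    # move every point halfway toward the centroid
    centroid = coors.mean(dim=1, keepdim=True)
    offset = coors - centroid
    return feats * offset.norm(dim=-1, keepdim=True), centroid + 0.5 * offset
```

With En-transformer installed, the same helper could presumably be called on the usage example above as `check_equivariance(lambda f, x: model(f, x, edges), feats, coors, atol=1e-4)` after `model.eval()`; float32 will likely need a looser tolerance than the float64 toy.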

Todo

  • masking
  • neighborhoods by radius
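A rough picture of the two Todo items: a boolean attention mask built from pairwise distances, applied by setting blocked logits to -inf before the softmax. The helper names and the dense O(n^2) formulation are assumptions for illustration, not the repository's planned design. Because the mask depends only on distances, it is unchanged by rotations, reflections and translations, so applying it should not break equivariance; a padding mask could be combined with it using `&`.

```python
import torch


def radius_neighbor_mask(coors: torch.Tensor, radius: float) -> torch.Tensor:
    """Hypothetical helper: (b, n, n) boolean mask, True where |x_i - x_j| <= radius.

    Distances come from explicit differences, so the diagonal is exactly 0 and
    every node counts as its own neighbor (no fully masked rows).
    """
    rel = coors[:, :, None, :] - coors[:, None, :, :]  # (b, n, n, 3)
    return rel.norm(dim=-1) <= radius


def masked_attention(q, k, v, mask):
    # q, k, v: (b, n, d); mask: (b, n, n), True = node i may attend to node j
    scores = q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5
    scores = scores.masked_fill(~mask, float("-inf"))  # drop out-of-radius pairs
    return scores.softmax(dim=-1) @ v
```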

Citations

@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}

Comments
  • Checkpoint sequential segments should equal number of layers instead of 1?

    https://github.com/lucidrains/En-transformer/blob/a37e635d93a322cafdaaf829397c601350b23e5b/en_transformer/en_transformer.py#L527

    Looking at the source code here: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint_sequential

    opened by aced125 2
  • On rotary embeddings

    Hi @lucidrains, thank you for your amazing work; big fan! I had a quick question on the usage of this repository.

    Based on my understanding, rotary embeddings are a drop-in replacement for the original sinusoidal or learnt PEs in Transformers for sequential data, as in NLP or other temporal applications. If my application is not on sequential data, is there a reason why I should still use rotary embeddings?

    E.g. for molecular datasets such as QM9 (from the En-GNNs paper), would it make sense to have rotary embeddings?

    opened by chaitjo 1
  • Is this line required?

    https://github.com/lucidrains/En-transformer/blob/7247e258fab953b2a8b5a73b8dfdfb72910711f8/en_transformer/en_transformer.py#L159

    Is this line required? Does line 157, two lines above, make this line redundant?

    opened by aced125 1
  • Performance drop with checkpointing update

    I see a drop in performance (higher loss) when I update checkpointing from checkpoint_sequential(self.layers, 1, inp) to checkpoint_sequential(self.layers, len(self.layers), inp). Is this expected?

    opened by heiidii 0
  • varying number of nodes

    @lucidrains Thank you for your efficient implementation. I was wondering how to use it on datasets where the number of nodes differs from graph to graph, for example small-molecule datasets.

    opened by mohaiminul2810 1
  • Edge model/rep

    Hi,

    Thank you for providing this version of the EnGNN model. This is not really an issue, just a query. The original model as implemented here (https://github.com/vgsatorras/egnn) has 3 main steps per layer:

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)

    I am interested in edge_feat and was wondering what the equivalent edge representation would be in your implementation. Line 335 in EnTransformer.py, qk = self.edge_mlp(qk), seems like the best candidate.

    Thanks,
    Pooja

    opened by heiidii 1
  • efficient implementation

    Hi, I wonder whether relative distances and coordinates can be handled more efficiently using memory-efficient attention, as in "Self-attention Does Not Need O(n^2) Memory". It is straightforward for the scalar part.

    opened by amrhamedp 2
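On the two checkpointing threads above: `torch.utils.checkpoint.checkpoint_sequential(functions, segments, input)` splits a sequential stack into `segments` chunks and recomputes activations inside all but the last chunk during the backward pass. In the PyTorch source linked in the first thread, the last segment is never checkpointed, so `segments=1` appears to checkpoint nothing at all, while `segments=len(layers)` checkpoints every layer except the last. The toy comparison below uses a stand-in stack, not En-transformer's layers, and assumes PyTorch 2.x, where `use_reentrant` can be passed to `checkpoint_sequential`. It checks that non-reentrant checkpointing reproduces the plain gradients, dropout included, because the RNG state is restored during recomputation. It also shows a documented pitfall of the reentrant variant that may be worth ruling out for the performance report (not a claim that it is the cause): if no input requires grad, the checkpointed layers receive no gradients.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential


def make_layers(seed: int = 0) -> nn.Sequential:
    # hypothetical stand-in for a stack of 4 blocks; reseeded so every call
    # builds identical weights
    torch.manual_seed(seed)
    return nn.Sequential(*[
        nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Dropout(0.1))
        for _ in range(4)
    ])


def grads_for(x, segments=None, use_reentrant=False):
    """Parameter gradients of a toy loss; segments=None disables checkpointing."""
    layers = make_layers()
    layers.train()          # keep dropout active
    torch.manual_seed(123)  # same dropout masks in every run
    if segments is None:
        out = layers(x)
    else:
        out = checkpoint_sequential(layers, segments, x, use_reentrant=use_reentrant)
    out.pow(2).mean().backward()
    # None marks a parameter that received no gradient
    return [None if p.grad is None else p.grad.clone() for p in layers.parameters()]


torch.manual_seed(0)
x = torch.randn(8, 16)  # note: does not require grad
plain = grads_for(x)
one_segment = grads_for(x, segments=1)
per_layer = grads_for(x, segments=4)
reentrant = grads_for(x, segments=4, use_reentrant=True)
```

If gradients match in a comparison like this, checkpointing alone is unlikely to change the loss; the reentrant case shows why it can still matter which tensors require grad.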
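On the rotary-embedding question: rotary embeddings (as in RoFormer) rotate each pair of query/key channels by an angle proportional to the token's integer position, so the query-key dot product depends only on the offset between the two positions. The generic sketch below (not taken from this repository) checks that property. It also makes the underlying assumption visible: positions must be meaningful indices, which the atoms of a molecule in a dataset such as QM9 generally are not, since their order is arbitrary.

```python
import torch


def rotary(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply a rotary position embedding to x of shape (n, d) at integer positions pos (n,)."""
    d = x.shape[-1]
    assert d % 2 == 0, "feature dimension must be even"
    # one frequency per channel pair
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)  # (d/2,)
    angles = pos.to(x.dtype)[:, None] * inv_freq[None, :]          # (n, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]  # (even, odd) channel pairs
    # rotate each 2D pair (x1, x2) by its angle
    out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return out.flatten(-2)  # interleave the pairs back into (n, d)
```

Since each pair is rotated, rotary(q, m) . rotary(k, n) equals q . rotary(k, n - m): shifting both positions by the same amount leaves the attention score unchanged.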
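On varying numbers of nodes: a common approach is to zero-pad every graph in a batch to the largest node count and carry a boolean mask. The helper names below are hypothetical. Whether EnTransformer's forward accepts such a mask depends on the installed version (masking appears under Todo above), so check its signature. The mask matters for equivariance too: padded nodes sit at the origin, and if they took part in attention or coordinate updates, translating the real nodes would change the output in a non-equivariant way. Grouping graphs of equal size into the same batch is another option.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_graphs(graphs):
    """Hypothetical helper: batch graphs with different node counts.

    graphs: list of (feats (n_i, dim), coors (n_i, 3)) pairs.
    Returns feats (b, n_max, dim), coors (b, n_max, 3) and a boolean mask
    (b, n_max) that is True for real nodes and False for padding.
    """
    feats = pad_sequence([f for f, _ in graphs], batch_first=True)  # zero-padded
    coors = pad_sequence([c for _, c in graphs], batch_first=True)
    lengths = torch.tensor([f.shape[0] for f, _ in graphs])
    mask = torch.arange(feats.shape[1])[None, :] < lengths[:, None]
    return feats, coors, mask


def pair_mask(mask):
    # (b, n) -> (b, n, n): a node pair is valid only if both nodes are real
    return mask[:, :, None] & mask[:, None, :]
```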
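On the edge-representation question: in the EGNN paper the edge message is m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2, a_ij), which is what `edge_model` computes in the quoted code. Below is a dense (all-pairs) sketch of that edge model; the SiLU activations and layer sizes are assumptions. It illustrates the EGNN formulation, not En-transformer's `edge_mlp`, whose inputs may differ. Its output has one vector per node pair, shape (b, n, n, hidden_dim), and is invariant to rotations and translations of the coordinates.

```python
import torch
from torch import nn


class DenseEdgeModel(nn.Module):
    """Hypothetical dense EGNN-style edge model: m_ij = phi_e(h_i, h_j, |x_i - x_j|^2, a_ij)."""

    def __init__(self, dim: int, edge_dim: int, hidden_dim: int):
        super().__init__()
        self.phi_e = nn.Sequential(
            nn.Linear(2 * dim + 1 + edge_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

    def forward(self, h, x, edges):
        # h: (b, n, dim), x: (b, n, 3), edges: (b, n, n, edge_dim)
        b, n, dim = h.shape
        h_i = h[:, :, None, :].expand(b, n, n, dim)    # features of node i
        h_j = h[:, None, :, :].expand(b, n, n, dim)    # features of node j
        rel = x[:, :, None, :] - x[:, None, :, :]      # (b, n, n, 3)
        dist_sq = (rel ** 2).sum(dim=-1, keepdim=True)  # (b, n, n, 1)
        return self.phi_e(torch.cat((h_i, h_j, dist_sq, edges), dim=-1))
```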
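On the memory-efficient attention question: the paper's idea for the scalar part is to process keys in chunks while tracking a running maximum, so only one block of scores exists at a time. For the coordinate part, if the weights on relative coordinates are softmax-normalized attention weights a_ij (an assumption here, not necessarily what En-transformer computes), then sum_j a_ij (x_i - x_j) = x_i - sum_j a_ij x_j, so the coordinates can be passed as extra values. Unnormalized weights w_ij work the same way, since x_i * sum_j w_ij and sum_j w_ij x_j can both be accumulated chunk by chunk, and squared distances within a chunk can be formed as |x_i|^2 + |x_j|^2 - 2 x_i . x_j. A minimal single-head sketch, without the paper's gradient checkpointing:

```python
import torch


def chunked_attention(q, k, v, chunk_size=4):
    """Softmax attention computed over key/value chunks with a running max.

    q: (n, d), k: (m, d), v: (m, e). Only one (n, chunk_size) block of scores
    is materialized at a time instead of the full (n, m) matrix.
    """
    scale = q.shape[-1] ** -0.5
    n = q.shape[0]
    running_max = torch.full((n, 1), float("-inf"), dtype=q.dtype)
    numer = torch.zeros(n, v.shape[-1], dtype=q.dtype)  # running sum of exp(s) * v
    denom = torch.zeros(n, 1, dtype=q.dtype)            # running sum of exp(s)
    for start in range(0, k.shape[0], chunk_size):
        s = (q @ k[start:start + chunk_size].T) * scale  # (n, c) block of scores
        new_max = torch.maximum(running_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(running_max - new_max)     # rescale earlier sums
        p = torch.exp(s - new_max)
        numer = numer * correction + p @ v[start:start + chunk_size]
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        running_max = new_max
    return numer / denom


def chunked_relative_coors(q, k, x, chunk_size=4):
    # sum_j a_ij (x_i - x_j) = x_i - sum_j a_ij x_j when each row of a sums to 1
    return x - chunked_attention(q, k, x, chunk_size)
```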
Releases(1.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.
A tool for making map images from OpenTTD save games

OpenTTD Surveyor A tool for making map images from OpenTTD save games. This is not part of the main OpenTTD codebase, nor is it ever intended to be pa

Aidan Randle-Conde 9 Feb 15, 2022
Wenet STT Python

Wenet STT Python Beta Software Simple Python library, distributed via binary wheels with few direct dependencies, for easily using WeNet models for sp

David Zurow 33 Feb 21, 2022
Python package to generate image embeddings with CLIP without PyTorch/TensorFlow

imgbeddings A Python package to generate embedding vectors from images, using OpenAI's robust CLIP model via Hugging Face transformers. These image em

Max Woolf 81 Jan 04, 2023
ArcFace reproduction based on the Paddle framework

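arcface-Paddle: ArcFace reproduction based on the Paddle framework. ArcFace-Paddle: this project reproduces ArcFace with the PaddlePaddle framework and is entered in Baidu's third paper reproduction competition; an AI Studio link will be provided after the competition ends on May 15, 2021. Stay tuned! Reference projects: InsightFace Padd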

QuanHao Guo 16 Dec 15, 2022
Instantaneous Motion Generation for Robots and Machines.

Ruckig Instantaneous Motion Generation for Robots and Machines. Ruckig generates trajectories on-the-fly, allowing robots and machines to react instan

Berscheid 374 Dec 23, 2022
Anonymize BLM Protest Images

Anonymize BLM Protest Images This repository automates @BLMPrivacyBot, a Twitter bot that shows the anonymized images to help keep protesters safe. Us

Stanford Machine Learning Group 40 Oct 13, 2022
Implementation of the Chamfer Distance as a module for pyTorch

Chamfer Distance for pyTorch This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension.

Christian Diller 205 Jan 05, 2023
This is a Keras implementation of a CNN for estimating age, gender and mask from a camera.

face-detector-age-gender This is a Keras implementation of a CNN for estimating age, gender and mask from a camera. Before run face detector app, expr

Devdreamsolution 2 Dec 04, 2021
Simple-Image-Classification - Simple Image Classification Code (PyTorch)

Simple-Image-Classification Simple Image Classification Code (PyTorch) Yechan Kim This repository contains: Python3 / Pytorch code for multi-class ima

Yechan Kim 8 Oct 29, 2022
Unity Propagation in Bayesian Networks Handling Inconsistency via Unity Smoothing

This repository contains the scripts needed to generate the results from the paper Unity Propagation in Bayesian Networks Handling Inconsistency via U

0 Jan 19, 2022
Makes patches from huge resolution .svs slide files using openslide

openslide_patcher Makes patches from huge resolution .svs slide files using openslide Example collage I made from outputs:

2 Dec 23, 2021
Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Phil Wang 180 Jan 05, 2023
Tutorial repo for an end-to-end Data Science project

End-to-end Data Science project This is the repo with the notebooks, code, and additional material used in the ITI's workshop. The goal of the session

Deena Gergis 127 Dec 30, 2022
Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, and DVC.

NN Template Generic template to bootstrap your PyTorch project. Click on Use this Template and avoid writing boilerplate code for: PyTorch Lightning,

Luca Moschella 520 Dec 30, 2022
This project contains an implemented version of Face Detection using OpenCV and Mediapipe. This is a code snippet and can be used in projects.

Live-Face-Detection Project Description: In this project, we will be using the live video feed from the camera to detect Faces. It will also detect so

Hassan Shahzad 3 Oct 02, 2021
LSTM and QRNN Language Model Toolkit for PyTorch

LSTM and QRNN Language Model Toolkit This repository contains the code used for two Salesforce Research papers: Regularizing and Optimizing LSTM Langu

Salesforce 1.9k Jan 08, 2023
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 01, 2023
The implementation of PEMP in paper "Prior-Enhanced Few-Shot Segmentation with Meta-Prototypes"

Prior-Enhanced network with Meta-Prototypes (PEMP) This is the PyTorch implementation of PEMP. Overview of PEMP Meta-Prototypes & Adaptive Prototypes

Jianwei ZHANG 8 Oct 14, 2021
DeepDiffusion: Unsupervised Learning of Retrieval-adapted Representations via Diffusion-based Ranking on Latent Feature Manifold

DeepDiffusion Introduction This repository provides the code of the DeepDiffusion algorithm for unsupervised learning of retrieval-adapted representat

4 Nov 15, 2022
Knowledge Management for Humans using Machine Learning & Tags

HyperTag HyperTag helps humans intuitively express how they think about their files using tags and machine learning.

Ravn Tech, Inc. 165 Nov 04, 2022