Implementation of E(n)-Transformer, which extends the ideas of Welling's E(n)-Equivariant Graph Neural Network to attention

Overview

E(n)-Equivariant Transformer (wip)

Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention.

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,
    dim_head = 64,
    heads = 8,
    edge_dim = 4,
    fourier_features = 2
)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = model(feats, coors, edges)  # (1, 16, 512), (1, 16, 3)

Todo

  • masking
  • neighborhoods by radius

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Checkpoint sequential segments should equal number of layers instead of 1?

    Checkpoint sequential segments should equal number of layers instead of 1?

    https://github.com/lucidrains/En-transformer/blob/a37e635d93a322cafdaaf829397c601350b23e5b/en_transformer/en_transformer.py#L527

    Looking at the source code here: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint_sequential

    opened by aced125 2
  • On rotary embeddings

    On rotary embeddings

    Hi @lucidrains, thank you for your amazing work; big fan! I had a quick question on the usage of this repository.

    Based on my understanding, rotary embeddings are a drop-in replacement for the original sinusoidal or learnt PEs in Transformers for sequential data, as in NLP or other temporal applications. If my application is not on sequential data, is there a reason why I should still use rotary embeddings?

    E.g. for molecular datasets such as QM9 (from the En-GNNs paper), would it make sense to have rotary embeddings?

    opened by chaitjo 1
  • Is this line required?

    Is this line required?

    https://github.com/lucidrains/En-transformer/blob/7247e258fab953b2a8b5a73b8dfdfb72910711f8/en_transformer/en_transformer.py#L159

    Is this line required? Does line 157, two lines above, make this line redundant?

    opened by aced125 1
  • Performance drop with checkpointing update

    Performance drop with checkpointing update

    I see a drop in performance (higher loss) when I update checkpointing from checkpoint_sequential(self.layers, 1, inp) to checkpoint_sequential(self.layers, len(self.layers), inp). Is this expected?

    opened by heiidii 0
  • varying number of nodes

    varying number of nodes

    @lucidrains Thank you for your efficient implementation. I was wondering how to use this implementation for the dataset when the number of nodes in each graph is not the same? For example, the datasets of small molecules.

    opened by mohaiminul2810 1
  • Edge model/rep

    Edge model/rep

    Hi,

    Thank you for providing this version of the EnGNN model. This is not really an issue just a query. The original model as implemented here (https://github.com/vgsatorras/egnn) has 3 main steps per layer: edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) coord = self.coord_model(coord, edge_index, coord_diff, edge_feat) h, agg = self.node_model(h, edge_index, edge_feat, node_attr) I am interested in the edge_feat and was wondering what would be an equivalent edge representation in your implementation. Line 335 in EnTransformer.py: qk = self.edge_mlp(qk) seems like the best candidate. Thanks, Pooja

    opened by heiidii 1
  • efficient implementation

    efficient implementation

    Hi, I wonder if relative distances and coordinates can be handled more efficiently using memory efficient attention as in " Self-attention Does Not Need O(n^2) Memory". It is straightforward for the scalar part.

    opened by amrhamedp 2
Releases(1.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
Use .csv files to record, play and evaluate motion capture data.

Purpose These scripts allow you to record mocap data to, and play from .csv files. This approach facilitates parsing of body movement data in statisti

21 Dec 12, 2022
Official repo for BMVC2021 paper ASFormer: Transformer for Action Segmentation

ASFormer: Transformer for Action Segmentation This repo provides training & inference code for BMVC 2021 paper: ASFormer: Transformer for Action Segme

42 Dec 23, 2022
Sharpness-Aware Minimization for Efficiently Improving Generalization

Sharpness-Aware-Minimization-TensorFlow This repository provides a minimal implementation of sharpness-aware minimization (SAM) (Sharpness-Aware Minim

Sayak Paul 54 Dec 08, 2022
A object detecting neural network powered by the yolo architecture and leveraging the PyTorch framework and associated libraries.

Yolo-Powered-Detector A object detecting neural network powered by the yolo architecture and leveraging the PyTorch framework and associated libraries

Luke Wilson 1 Dec 03, 2021
A library of extension and helper modules for Python's data analysis and machine learning libraries.

Mlxtend (machine learning extensions) is a Python library of useful tools for the day-to-day data science tasks. Sebastian Raschka 2014-2020 Links Doc

Sebastian Raschka 4.2k Jan 02, 2023
《Single Image Reflection Removal Beyond Linearity》(CVPR 2019)

Single-Image-Reflection-Removal-Beyond-Linearity Paper Single Image Reflection Removal Beyond Linearity. Qiang Wen, Yinjie Tan, Jing Qin, Wenxi Liu, G

Qiang Wen 51 Jun 24, 2022
Code for paper Novel View Synthesis via Depth-guided Skip Connections

Novel View Synthesis via Depth-guided Skip Connections Code for paper Novel View Synthesis via Depth-guided Skip Connections @InProceedings{Hou_2021_W

8 Mar 14, 2022
IndoNLI: A Natural Language Inference Dataset for Indonesian

IndoNLI: A Natural Language Inference Dataset for Indonesian This is a repository for data and code accompanying our EMNLP 2021 paper "IndoNLI: A Natu

15 Feb 10, 2022
Code for CVPR2021 "Visualizing Adapted Knowledge in Domain Transfer". Visualization for domain adaptation. #explainable-ai

Visualizing Adapted Knowledge in Domain Transfer @inproceedings{hou2021visualizing, title={Visualizing Adapted Knowledge in Domain Transfer}, auth

Yunzhong Hou 80 Dec 25, 2022
CoRe: Contrastive Recurrent State-Space Models

CoRe: Contrastive Recurrent State-Space Models This code implements the CoRe model and reproduces experimental results found in Robust Robotic Control

Apple 21 Aug 11, 2022
Code, pre-trained models and saliency results for the paper "Boosting RGB-D Saliency Detection by Leveraging Unlabeled RGB Images".

Boosting RGB-D Saliency Detection by Leveraging Unlabeled RGB This repository is the official implementation of the paper. Our results comming soon in

Xiaoqiang Wang 8 May 22, 2022
Source code of all the projects of Udacity Self-Driving Car Engineer Nanodegree.

self-driving-car In this repository I will share the source code of all the projects of Udacity Self-Driving Car Engineer Nanodegree. Hope this might

Andrea Palazzi 2.4k Dec 29, 2022
The Instructed Glacier Model (IGM)

The Instructed Glacier Model (IGM) Overview The Instructed Glacier Model (IGM) simulates the ice dynamics, surface mass balance, and its coupling thro

27 Dec 16, 2022
PAthological QUpath Obsession - QuPath and Python conversations

PAQUO: PAthological QUpath Obsession Welcome to paquo 👋 , a library for interacting with QuPath from Python. paquo's goal is to provide a pythonic in

Bayer AG 60 Dec 31, 2022
The code for replicating the experiments from the LFI in SSMs with Unknown Dynamics paper.

Likelihood-Free Inference in State-Space Models with Unknown Dynamics This package contains the codes required to run the experiments in the paper. Th

Alex Aushev 0 Dec 27, 2021
The implementation of the paper "A Deep Feature Aggregation Network for Accurate Indoor Camera Localization".

A Deep Feature Aggregation Network for Accurate Indoor Camera Localization This is the PyTorch implementation of our paper "A Deep Feature Aggregation

9 Dec 09, 2022
CZU-MHAD: A multimodal dataset for human action recognition utilizing a depth camera and 10 wearable inertial sensors

CZU-MHAD: A multimodal dataset for human action recognition utilizing a depth camera and 10 wearable inertial sensors   In order to facilitate the res

yujmo 11 Dec 12, 2022
A embed able annotation tool for end to end cross document co-reference

CoRefi CoRefi is an emebedable web component and stand alone suite for exaughstive Within Document and Cross Document Coreference Anntoation. For a de

PythicCoder 39 Dec 12, 2022
A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX

Foolbox Native: Fast adversarial attacks to benchmark the robustness of machine learning models in PyTorch, TensorFlow, and JAX Foolbox is a Python li

Bethge Lab 2.4k Dec 25, 2022
Code for the paper: Sketch Your Own GAN

Sketch Your Own GAN Project | Paper | Youtube Our method takes in one or a few hand-drawn sketches and customizes an off-the-shelf GAN to match the in

677 Dec 28, 2022