EGNN - Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. It may eventually be used for AlphaFold2 replication. The technique relies on simple invariant features, yet it beat all previous methods (including the SE3 Transformer and Lie Conv) in both accuracy and performance, reaching state-of-the-art results on dynamical system modelling, molecular activity prediction and other tasks.
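The layer relies on simple invariant features. Below is a minimal sketch of one layer, written from the update equations in the paper: messages m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2), coordinates moved along x_i - x_j, and h_i' = phi_h(h_i, m_i). It assumes a dense, fully connected graph. It is not the egnn_pytorch code, and the class name MiniEGNNLayer and its MLP shapes are hypothetical choices.

```python
import torch
from torch import nn


class MiniEGNNLayer(nn.Module):
    # Hypothetical minimal layer following the EGNN update equations of the paper;
    # NOT the egnn_pytorch implementation (dense, fully connected graph only).
    def __init__(self, dim: int, m_dim: int = 16):
        super().__init__()
        # phi_e: message from (h_i, h_j, ||x_i - x_j||^2)
        self.edge_mlp = nn.Sequential(nn.Linear(2 * dim + 1, m_dim), nn.SiLU())
        # phi_x: one scalar weight per pair, used to scale x_i - x_j
        self.coors_mlp = nn.Linear(m_dim, 1)
        # phi_h: node update from (h_i, m_i)
        self.node_mlp = nn.Sequential(nn.Linear(dim + m_dim, dim), nn.SiLU())

    def forward(self, feats: torch.Tensor, coors: torch.Tensor):
        # feats: (b, n, dim), coors: (b, n, 3)
        n = feats.shape[1]
        rel_coors = coors[:, :, None, :] - coors[:, None, :, :]         # x_i - x_j, (b, n, n, 3)
        rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True)          # squared distance, E(n)-invariant
        h_i = feats[:, :, None, :].expand(-1, -1, n, -1)
        h_j = feats[:, None, :, :].expand(-1, n, -1, -1)
        m_ij = self.edge_mlp(torch.cat((h_i, h_j, rel_dist), dim=-1))  # (b, n, n, m_dim)
        not_self = ~torch.eye(n, dtype=torch.bool, device=feats.device)[None, :, :, None]
        # x_i' = x_i + C * sum_{j != i} (x_i - x_j) * phi_x(m_ij), with C = 1 / (n - 1)
        coors = coors + (rel_coors * self.coors_mlp(m_ij) * not_self).sum(dim=2) / (n - 1)
        # h_i' = phi_h(h_i, m_i), with m_i = sum_{j != i} m_ij
        m_i = (m_ij * not_self).sum(dim=2)
        feats = self.node_mlp(torch.cat((feats, m_i), dim=-1))
        return feats, coors


# Usage: an orthogonal transform plus translation of the input should leave the
# features unchanged and move the output coordinates by the same transform.
torch.manual_seed(0)
layer = MiniEGNNLayer(dim=8)
feats, coors = torch.randn(1, 5, 8), torch.randn(1, 5, 3)
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal 3x3 matrix
t = torch.randn(1, 1, 3)
f1, c1 = layer(feats, coors)
f2, c2 = layer(feats, coors @ Q + t)
print(torch.allclose(f1, f2, atol=1e-4), torch.allclose(c1 @ Q + t, c2, atol=1e-4))
```

Only invariant quantities (squared distances) enter the MLPs, and coordinates change only along the relative vectors x_i - x_j. That is why the check passes without any spherical harmonics or group convolutions.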

Install

$ pip install egnn-pytorch

Usage

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)

With edges

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
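The example above fills edges with random numbers. In practice the tensor carries real per-pair features with shape (batch, n, n, edge_dim), where edge_dim matches the value passed to EGNN. Here is a sketch that builds such a tensor from a bond list as one-hot bond types. The helper edges_from_bonds and the (i, j, bond_type) input format are hypothetical.

```python
import torch


def edges_from_bonds(num_nodes, bonds, edge_dim):
    # Hypothetical helper: (i, j, bond_type) triples -> one-hot edge tensor of shape
    # (1, num_nodes, num_nodes, edge_dim); node pairs without a bond stay all-zero.
    edges = torch.zeros(1, num_nodes, num_nodes, edge_dim)
    for i, j, bond_type in bonds:
        edges[0, i, j, bond_type] = 1.0
        edges[0, j, i, bond_type] = 1.0  # undirected bond: fill both (i, j) and (j, i)
    return edges


edges = edges_from_bonds(16, [(0, 1, 0), (1, 2, 3), (2, 3, 1)], edge_dim=4)
print(edges.shape)  # torch.Size([1, 16, 16, 4])
```

The resulting tensor has the same shape as the random edges above, so it could be passed as layer1(feats, coors, edges).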

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • training batch size

    Dear authors,

    Thanks for your great work! I looked at your example, which is easy to understand. However, I notice that during training, each iteration seems to support batch size > 1 only if all graphs in the batch share the same adj_mat. Do you have a better solution for that? Thanks.

    opened by futianfan 6
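One common way to batch graphs of different sizes, each with its own adjacency, is to pad them to the largest graph and carry a node mask. This page does not describe the masking arguments EGNN itself accepts, so the helper below only builds the tensors. The names pad_graph_batch and random_graph are hypothetical.

```python
import torch


def pad_graph_batch(graphs, dim):
    # Hypothetical helper: graphs is a list of (feats (n_k, dim), coors (n_k, 3), adj (n_k, n_k) bool).
    # Returns dense batch tensors padded to the largest graph, plus a node mask.
    b = len(graphs)
    n_max = max(f.shape[0] for f, _, _ in graphs)
    feats = torch.zeros(b, n_max, dim)
    coors = torch.zeros(b, n_max, 3)
    adj = torch.zeros(b, n_max, n_max, dtype=torch.bool)
    mask = torch.zeros(b, n_max, dtype=torch.bool)
    for k, (f, c, a) in enumerate(graphs):
        n = f.shape[0]
        feats[k, :n], coors[k, :n] = f, c
        adj[k, :n, :n] = a   # each graph keeps its own adjacency matrix
        mask[k, :n] = True   # True marks real nodes, False marks padding
    return feats, coors, adj, mask


def random_graph(n, dim):
    # Hypothetical toy graph: symmetric random adjacency without self-loops.
    a = torch.rand(n, n) < 0.5
    a = (a | a.T) & ~torch.eye(n, dtype=torch.bool)
    return torch.randn(n, dim), torch.randn(n, 3), a


feats, coors, adj, mask = pad_graph_batch([random_graph(3, 8), random_graph(5, 8)], dim=8)
# pairs allowed to exchange messages: connected in their own graph and both nodes real
pair_mask = adj & mask[:, :, None] & mask[:, None, :]
print(feats.shape, adj.shape, mask.sum().item())
```

Messages can then be multiplied by pair_mask[..., None] before aggregation, so padded nodes and non-edges contribute nothing.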
  • Import Error when torch_geometric is not available

    https://github.com/lucidrains/egnn-pytorch/blob/e35510e1be94ee9f540bf2ffea49cd63578fe473/egnn_pytorch/egnn_pytorch.py#L413

    A small problem: the name Tensor is not defined at this line when torch_geometric is not available.

    Thanks for your work.

    opened by zrt 4
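A common pattern for an optional dependency is sketched below. It is an assumption, not necessarily the fix applied in the repository. Names such as Tensor come from torch unconditionally, and only the PyG-specific base class sits inside the try block. PYG_AVAILABLE and squared_norm are hypothetical names.

```python
import torch
from torch import Tensor  # imported unconditionally, so annotations resolve even without PyG

try:
    from torch_geometric.nn import MessagePassing  # optional dependency
    PYG_AVAILABLE = True
except ImportError:
    MessagePassing = object  # placeholder base class so the module still imports
    PYG_AVAILABLE = False


def squared_norm(x: Tensor) -> Tensor:
    # Uses only torch, so it works whether or not torch_geometric is installed.
    return (x ** 2).sum(dim=-1)


print(PYG_AVAILABLE, squared_norm(torch.tensor([[3.0, 4.0]])))
```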
  • About aggregations in EGNN_sparse

    Hi, thanks for your great work!

    I have a question about how the aggregations are computed for the node embedding and the coordinate embedding. In the paper, the aggregation for the node embedding is computed over a node's neighbors, while the aggregation for the coordinate embedding is computed over all other nodes. However, in EGNN_sparse, I didn't notice such a difference in the aggregations.

    I guess it is because computing all-pair messages for coordinate embedding makes 'sparse' meaningless, but I would like to double-check to see if I get this correctly. So anyway, did you do this intentionally? Or did I miss something?

    My appreciation.

    opened by simon1727 4
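The difference the question describes can be made concrete with a dense sketch. Node messages are summed over graph neighbours N(i). Coordinate updates are summed over every j != i and scaled by a normalizing constant C, taken here as 1 / (n - 1). The function aggregate and its argument layout are hypothetical. Note that the coordinate sum touches all n^2 pairs, which is the cost the question suspects a sparse layer avoids.

```python
import torch


def aggregate(m_ij, coors_w, rel_coors, adj):
    # Hypothetical dense sketch of the two aggregations described in the paper.
    # m_ij: (n, n, m_dim) messages, coors_w: (n, n, 1) = phi_x(m_ij),
    # rel_coors: (n, n, 3) = x_i - x_j, adj: (n, n) bool adjacency.
    n = m_ij.shape[0]
    not_self = ~torch.eye(n, dtype=torch.bool)
    # node aggregation over graph neighbours only: m_i = sum_{j in N(i)} m_ij
    m_i = (m_ij * (adj & not_self)[..., None]).sum(dim=1)
    # coordinate aggregation over all other nodes: C * sum_{j != i} (x_i - x_j) * phi_x(m_ij)
    delta = (rel_coors * coors_w * not_self[..., None]).sum(dim=1) / (n - 1)
    return m_i, delta


# Path graph 0 - 1 - 2: nodes 0 and 2 are not neighbours.
x = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
adj = torch.tensor([[False, True, False], [True, False, True], [False, True, False]])
m_ij = torch.ones(3, 3, 4)
m_i, delta = aggregate(m_ij, torch.ones(3, 3, 1), x[:, None] - x[None, :], adj)
print(m_i[:, 0], delta[:, 0])
```

Node 0 receives a message only from node 1, but its coordinate update also includes the non-neighbour node 2.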
  • Few queries on the implementation

    Hi - fast work coding these things up, as usual! Looking at the paper and your code, you're not using squared distance for the edge weighting. Is that intentional? Also, it looks like you are adding the old feature vectors to the new ones rather than taking the new vectors directly from the fully connected net - is that also an intentional change from the paper?

    opened by denjots 3
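Both points in the question can be shown side by side. The paper feeds the squared distance ||x_i - x_j||^2 into phi_e and sets h_i' = phi_h(h_i, m_i), while the residual variant adds h_i back. One practical consideration, which is our own observation rather than the author's stated reason: if you take a square root to get plain distances, the derivative of sqrt is infinite at zero distance (the i == j diagonal), which produces NaN gradients. The tensor sizes below are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
coors = torch.randn(5, 3, requires_grad=True)
rel_coors = coors[:, None, :] - coors[None, :, :]   # x_i - x_j, (5, 5, 3)
sq_dist = (rel_coors ** 2).sum(dim=-1)              # ||x_i - x_j||^2, as written in the paper
dist = sq_dist.sqrt()                               # plain Euclidean distance

# sqrt has an infinite derivative at 0, so the zero diagonal (i == j) poisons the gradient
dist.sum().backward()
print(torch.isnan(coors.grad).any())

# Node update: paper form vs residual form
h, m_i = torch.randn(5, 8), torch.randn(5, 4)
phi_h = nn.Linear(8 + 4, 8)
h_paper = phi_h(torch.cat((h, m_i), dim=-1))         # h_i' = phi_h(h_i, m_i)
h_residual = h + phi_h(torch.cat((h, m_i), dim=-1))  # h_i' = h_i + phi_h(h_i, m_i)
```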
  • Fix PyG problems, add example for point cloud denoising

    • Fixed some tiny errors in the data flow of the PyG layers (mainly dimensions and slices)
    • Fixed EGNN_Sparse_Network so that it now works
    • Provides an example of point cloud denoising (from Gaussian-masked coordinates) that showcases potential issues:
      • unstable training (possibly due to the nature of the data, not sure, but GVP does well on it)
      • not able to beat the baseline (GVP gets to 0.8 RMSD, while this reaches the 1 RMSD baseline but not below it)
    opened by hypnopump 2
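A sketch of the RMSD metric quoted above, for (n, 3) point sets without any alignment step. We assume the "baseline" means the RMSD of the noisy input itself, that is, a model that returns its input unchanged. The PR does not define it, and the noise level 0.5 is hypothetical.

```python
import torch


def rmsd(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Root-mean-square deviation between two (n, 3) point sets (no superposition/alignment).
    return ((pred - target) ** 2).sum(dim=-1).mean().sqrt()


torch.manual_seed(0)
clean = torch.randn(100, 3)
noisy = clean + 0.5 * torch.randn_like(clean)  # hypothetical noise level
# "identity" baseline: return the noisy input unchanged; a denoiser should score below this
print(rmsd(noisy, clean), rmsd(clean, clean))
```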
  • EGNN_sparse incorrect positional encoding output

    Hi, many thanks for the implementation!

    I was quickly checking the code for the pytorch geometric implementation of the EGNN_sparse layer, and I noticed that it expects the first 3 columns in the features to be the coordinates. However, in the update method, features and coordinates are passed in the wrong order.

    https://github.com/lucidrains/egnn-pytorch/blob/375d686c749a685886874baba8c9e0752db5f5be/egnn_pytorch/egnn_pytorch.py#L192

    This may cause problems during learning (think of concatenating several of these layers), as they expect coordinate and feature order to be consistent.

    One can reproduce this behaviour in the following snippet:

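    import math
    import torch
    from egnn_pytorch.egnn_pytorch import EGNN_sparse  # assumed import path, matching the file linked above

    def rot(a, b, c):
        # 3x3 rotation matrix R = Rz(a) @ Ry(b) @ Rx(c) from three angles
        ca, sa, cb, sb, cc, sc = math.cos(a), math.sin(a), math.cos(b), math.sin(b), math.cos(c), math.sin(c)
        Rz = torch.tensor([[ca, -sa, 0.], [sa, ca, 0.], [0., 0., 1.]])
        Ry = torch.tensor([[cb, 0., sb], [0., 1., 0.], [-sb, 0., cb]])
        Rx = torch.tensor([[1., 0., 0.], [0., cc, -sc], [0., sc, cc]])
        return Rz @ Ry @ Rx

    layer = EGNN_sparse(feats_dim=1, pos_dim=3, m_dim=16, fourier_features=0)

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(16, 1)
    coors = torch.randn(16, 3)
    x1 = torch.cat([coors, feats], dim=-1)                       # columns: [coords | feats]
    x2 = torch.cat([(coors @ R + T).squeeze(0), feats], dim=-1)  # rotated and translated copy
    edge_idxs = (torch.rand(2, 20) * 16).long()                  # random (2, n_edges) edges

    out1 = layer(x=x1, edge_index=edge_idxs)
    out2 = layer(x=x2, edge_index=edge_idxs)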

    After fixing the order of these arguments in the update method, the layer behaves as expected (output features are invariant and coordinates are equivariant under se(3) transformations).

    opened by josejimenezluna 2
  • NaN values after stacking multiple layers

    Hi Lucid!!

    I find that when stacking multiple layers, the output of the model rapidly goes to NaN. I suspect it may be related to the weights used for initialization.

    Here is a minimal working example:

    Make some data:

        import numpy as np
        import torch
        from egnn_pytorch import EGNN
        
        torch.set_default_dtype(torch.double)
    
        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:]
        feat = torch.randint(0, 20, (1, geom.shape[1], 1)).to(torch.get_default_dtype())  # integer labels cast to floating point
    

    Make a model:

        class ResEGNN(torch.nn.Module):
            def __init__(self, depth = 2, dims_in = 1):
                super().__init__()
                self.layers = torch.nn.ModuleList([EGNN(dim = dims_in) for i in range(depth)])
            
            def forward(self, geom, feat):
                for layer in self.layers:
                    feat, geom = layer(feat, geom)
                return geom
    

    Run model for varying depths:

        for i in range(10):
            model = ResEGNN(depth = i)
            pred = model(geom, feat)
            mean_absolute_value  = torch.abs(pred).mean()
            print("Order of predictions {:.2f}".format(np.log(mean_absolute_value.detach().numpy())))
    

    Output:

        Order of predictions -0.29
        Order of predictions 0.05
        Order of predictions 6.65
        Order of predictions 21.38
        Order of predictions 78.25
        Order of predictions 302.71
        Order of predictions 277.38
        Order of predictions nan
        Order of predictions nan
        Order of predictions nan

    opened by brennanaba 2
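When outputs blow up with depth, it helps to know which layer first produces a non-finite value. Below is a small diagnostic sketch for any stack of layers that map (feats, coors) to (feats, coors), which is the calling convention the usage examples show for EGNN. The function first_nonfinite_layer and the toy BlowUp layer are hypothetical.

```python
import torch
from torch import nn


def first_nonfinite_layer(layers, feats, coors):
    # Run layers that map (feats, coors) -> (feats, coors) in order and return the index
    # of the first one whose output contains NaN or inf, or None if everything stays finite.
    with torch.no_grad():
        for idx, layer in enumerate(layers):
            feats, coors = layer(feats, coors)
            if not (torch.isfinite(feats).all() and torch.isfinite(coors).all()):
                return idx
    return None


class BlowUp(nn.Module):
    # Hypothetical toy layer that scales coordinates by 1e30, overflowing float32 on the 2nd call.
    def forward(self, feats, coors):
        return feats, coors * 1e30


layers = nn.ModuleList([BlowUp() for _ in range(4)])
print(first_nonfinite_layer(layers, torch.zeros(1, 4, 1), torch.ones(1, 4, 3)))
```

With the ResEGNN model above, first_nonfinite_layer(model.layers, feat, geom) would report the depth at which coordinates overflow. Note that its argument order is (feats, coors).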
  • Edge features thrown out

    Hi, thanks for this implementation!

    I was wondering if the pytorch-geometric implementation of this architecture is throwing the edge features out by mistake, as seen here

    https://github.com/lucidrains/egnn-pytorch/blob/1b8320ade1a89748e4042ae448626652f1c659a1/egnn_pytorch/egnn_pytorch.py#L148-L151

    Or maybe my understanding is wrong? Cheers,

    opened by josejimenezluna 2
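In the paper the edge attributes a_ij are an input to the message function, m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2, a_ij). A sketch of an edge-list message function that keeps them follows. The class EdgeMessage is hypothetical, and it assumes PyG's default source_to_target convention, where row 0 of edge_index holds the source j and row 1 the target i.

```python
import torch
from torch import nn


class EdgeMessage(nn.Module):
    # Hypothetical phi_e that keeps edge attributes, following the paper's
    # m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2, a_ij).
    def __init__(self, feat_dim, edge_dim, m_dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * feat_dim + 1 + edge_dim, m_dim), nn.SiLU())

    def forward(self, feats, pos, edge_index, edge_attr):
        # edge_index: (2, n_edges); assumed: row 0 is the source j, row 1 the target i
        j, i = edge_index
        sq_dist = ((pos[i] - pos[j]) ** 2).sum(dim=-1, keepdim=True)
        # edge_attr (n_edges, edge_dim) is concatenated instead of being dropped
        return self.mlp(torch.cat((feats[i], feats[j], sq_dist, edge_attr), dim=-1))


msg = EdgeMessage(feat_dim=8, edge_dim=4, m_dim=16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
m_ij = msg(torch.randn(4, 8), torch.randn(4, 3), edge_index, torch.randn(3, 4))
print(m_ij.shape)  # torch.Size([3, 16])
```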
  • solve ij -> i bottleneck in sparse version

    I don't recommend normalizing either the weights or the coords.

    • The weights are the coefficients that multiply the delta in the i->j direction.
    • The coords are the deltas in the i->j direction. I can't see an advantage in normalizing them beyond a naive stabilization, which might hurt convergence by requiring more layers, since each layer would be limited in how far it can transform the coordinates.

    It works fine for denoising without normalization (the instability might come from huge outliers, but tuning the learning rate or clipping the gradients might help).

    opened by hypnopump 0
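Gradient clipping, one of the remedies suggested above, takes a single extra line in a training step using torch.nn.utils.clip_grad_norm_. The model below is a hypothetical stand-in for an EGNN-based network, and max_norm=1.0 is an arbitrary choice.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 3))  # stand-in for an EGNN model
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x, target = 100 * torch.randn(32, 3), torch.randn(32, 3)  # large inputs -> large gradients

loss = ((model(x) - target) ** 2).mean()
opt.zero_grad()
loss.backward()
# rescale all gradients together so their global L2 norm is at most max_norm;
# the return value is the norm measured before clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```

Clipping by the global norm keeps the direction of the update and only limits its size. Because of that it tends to interfere less with learning than clamping individual gradient entries.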
  • Questions about the EGNN code

    Recently, I've been trying to read the EGNN paper and study your EGNN code. I had a hard time understanding both because my major is not computer science. While studying your code, I realized that hidden_out and kwargs["x"] must have the same shape for the add operation (because of the residual connection) in the forward method of class EGNN_sparse. How can I increase or decrease the hidden dimension size of x?

    I would like to get some advice.

    Thanks for your consideration in this regard.

    opened by Byun-jinyoung 0
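Two common ways to change the hidden width are sketched below. Both are assumptions, not the library's API. The first embeds the feature columns of x to the desired width once, before the stack, so every layer can be created with feats_dim equal to that width. The second uses a residual block whose skip path carries a linear projection when input and output widths differ. The coordinate columns ([coords | feats] layout, as described for EGNN_sparse's x) are left untouched, because passing them through a learned Linear layer would break equivariance. ResidualProj and the widths used are hypothetical.

```python
import torch
from torch import nn

pos_dim, feat_dim, hidden_dim = 3, 6, 32
x = torch.randn(10, pos_dim + feat_dim)   # [coords | feats] layout

# Option 1: embed only the feature columns once, before the stack of layers.
embed = nn.Linear(feat_dim, hidden_dim)
coors, feats = x[:, :pos_dim], x[:, pos_dim:]
x_hidden = torch.cat((coors, embed(feats)), dim=-1)   # (10, pos_dim + hidden_dim)


class ResidualProj(nn.Module):
    # Option 2 (hypothetical): residual block whose input and output widths may differ:
    # out = proj(h) + f(h), with proj the identity when the widths match.
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))
        self.proj = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, h):
        return self.proj(h) + self.f(h)


out = ResidualProj(hidden_dim, 64)(x_hidden[:, pos_dim:])
print(x_hidden.shape, out.shape)
```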
  • Wrong edge_index size hint in class EGNN_Sparse of pyg version

    Hi, I found there may be a small mistake. In the input hint of class EGNN_Sparse in the pyg version, the size of edge_index is given as (n_edges, 2). However, it should be (2, n_edges); otherwise, the distance calculation will not be correct:

        """
        Inputs:
        * x: (n_points, d) where d is pos_dims + feat_dims
        * edge_index: (n_edges, 2)
        * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
        * batch: (n_points,) long tensor. specifies xloud belonging for each point
        * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
        * size: None
        """

    opened by Layne-Huang 2
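Here is a short illustration of why the layout matters. PyG stores edge_index as a (2, n_edges) tensor, so unpacking it along the first dimension yields one row of source indices and one row of target indices. An (n_edges, 2) edge list can be converted with a transpose. The toy positions and edges are hypothetical.

```python
import torch

pos = torch.randn(4, 3)
edge_list = torch.tensor([[0, 1], [1, 2], [2, 3]])  # (n_edges, 2): the layout in the docstring hint
edge_index = edge_list.t().contiguous()             # (2, n_edges): the layout PyG uses

src, dst = edge_index                               # unpacks dim 0 into two (n_edges,) index rows
dist = (pos[src] - pos[dst]).norm(dim=-1)           # one distance per edge, shape (n_edges,)

# Unpacking the (n_edges, 2) tensor directly would fail with 3 edges (too many values
# to unpack) and silently pair the wrong nodes when there happen to be exactly 2 edges.
print(edge_index.shape, dist.shape)
```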
  • Exploding Gradients With 4 Layers

    I'm using EGNN with 4 layers (where I also do global attention after each layer), and I'm seeing exploding gradients after 90 epochs or so. I'm using techniques discussed earlier (sparse attention matrix, coor_weights_clamp_value, norm_coors), but I'm not sure if there's anything else I should be doing. I'm also not updating the coordinates, so the fix in the pull request doesn't apply.

    opened by cutecows 0
  • Added optional tanh to coors_mlp

    This removes the NaN bug completely (norm_coors must also be used, otherwise performance dies).

    The NaN bug comes from the coors_mlp exploding, so forcing values between -1 and 1 prevents this. If coordinates are normalised then performance should not be adversely affected.

    opened by jscant 1
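The PR description above can be sketched as a coordinate update with a tanh at the end of coors_mlp and relative vectors normalised to unit length. With both in place, each node moves by at most 1 per layer, however large the messages get. The class BoundedCoorsUpdate, its MLP layout and eps are hypothetical, not the PR's code.

```python
import torch
from torch import nn


class BoundedCoorsUpdate(nn.Module):
    # Hypothetical coordinate update: weights squashed by tanh into (-1, 1) and
    # relative vectors normalised to (near) unit length, so each step has norm <= 1.
    def __init__(self, m_dim, eps=1e-8):
        super().__init__()
        self.coors_mlp = nn.Sequential(nn.Linear(m_dim, m_dim), nn.SiLU(), nn.Linear(m_dim, 1), nn.Tanh())
        self.eps = eps

    def forward(self, coors, m_ij):
        # coors: (n, 3), m_ij: (n, n, m_dim)
        n = coors.shape[0]
        rel = coors[:, None, :] - coors[None, :, :]               # x_i - x_j, (n, n, 3)
        rel = rel / (rel.norm(dim=-1, keepdim=True) + self.eps)   # unit direction (zero on the diagonal)
        w = self.coors_mlp(m_ij)                                  # (n, n, 1), each weight in (-1, 1)
        return coors + (rel * w).sum(dim=1) / (n - 1)


torch.manual_seed(0)
upd = BoundedCoorsUpdate(m_dim=16)
coors = torch.randn(6, 3)
m_ij = 1e6 * torch.randn(6, 6, 16)  # huge messages, as when coors_mlp explodes
new = upd(coors, m_ij)
print(torch.isfinite(new).all(), (new - coors).norm(dim=-1).max())
```

The trade-off matches the PR note. Bounding the step prevents the blow-up, but the magnitude information is carried only through the tanh-squashed weights.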
Releases (0.2.6)
Owner
Phil Wang
Working with Attention. It's all we need.