Implementation of Invariant Point Attention, used for coordinate refinement in the structure module of Alphafold2, as a standalone Pytorch module

Overview

Invariant Point Attention - Pytorch

Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of Alphafold2 for coordinate refinement.

  • write up a test for invariance under rotation
  • enforce float32 for certain operations

Install

$ pip install invariant-point-attention

Usage

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = torch.randn(1, 256, 64)      # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask          = torch.ones(1, 256).bool()    # (batch x seq)

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)  # (batch x seq x rot1 x rot2) - example is identity
translations  = torch.zeros(1, 256, 3) # translation, also identity for example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use this module without the pairwise representations, which is very specific to the Alphafold2 architecture.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    require_pairwise_repr = False   # set this to False to use the module without pairwise representations
)

seq           = torch.randn(1, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

attn_out = attn(
    seq,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use one IPA-based transformer block, which is an IPA followed by a feedforward. By default it will use post-layernorm as done in the official code, but you can also try pre-layernorm by setting post_norm = False

import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock

block = IPABlock(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

seq           = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

block_out = block(
    seq,
    pairwise_repr = pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)

# apply updates to rotations and translations for the next iteration

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
Comments
  • Computing point dist - use cartesian dimension instead of hidden dimension

    Computing point dist - use cartesian dimension instead of hidden dimension

    https://github.com/lucidrains/invariant-point-attention/blob/2f1fb7ca003d9c94d4144d1f281f8cbc914c01c2/invariant_point_attention/invariant_point_attention.py#L130

    I think it should be dim=-1, thus using the cartesian (xyz) axis, rather than dim=-2, which uses the hidden dimension.

    opened by aced125 3
  • In-place rotation detach not allowed

    In-place rotation detach not allowed

    Hi, this is probably highly version-dependent (I have pytorch=1.11.0, pytorch3d=0.7.0 nightly), but I thought I'd report it. Torch doesn't like the in-place detach of the rotation tensor. Full stack trace (from denoise.py):

    Traceback (most recent call last):
      File "denoise.py", line 56, in <module>
        denoised_coords = net(
      File "/home/pi-user/miniconda3/envs/piai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/pi-user/invariant-point-attention/invariant_point_attention/invariant_point_attention.py", line 336, in forward
        rotations.detach_()
    RuntimeError: Can't detach views in-place. Use detach() instead. If you are using DistributedDataParallel (DDP) for training, and gradient_as_bucket_view is set as True, gradients are views of DDP buckets, and hence detach_() cannot be called on these gradients. To fix this error, please refer to the Optimizer.zero_grad() function in torch/optim/optimizer.py as the solution.
    

    Switching to rotations = rotations.detach() seems to behave correctly (tested in denoise.py and my own code). I'm not totally sure if this allocates a separate tensor, or just creates a new node pointing to the same data.

    opened by sidnarayanan 1
  • Report a bug that causes instability in training

    Report a bug that causes instability in training

    Hi, I would like to report a bug in the rotation, that causes instability in training. https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L322

    The IPA Transformer is similar to the structure module in AF2, where the recycling is used. Note that we usually detach the gradient of rotation, which may causes instability during training. The reason is that the gradient of rotation would update the rotation during back propagation, which results in the instability based on experiments. Therefore we usually detach the rotation to dispel the updating effect of gradient descent. I have seen you do this in your alphafold2 repo (https://github.com/lucidrains/alphafold2).

    If you think this is a problem, please let me know. I am happy to submit a pr to fix that.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • Subtle mistake in the implementation

    Subtle mistake in the implementation

    Hi. Thanks for your implementation. It is very helpful. However, I find that you miss the dropout in the IPAModule.

    https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L239

    In the alphafold2 supplementary, the dropout is nested in the layer norm, which also holds true in the layer norm at transition layer (line 9 in the figure below). image

    If you think this is a problem, please let me know. I will submit a pr to fix it. Thanks again for sharing such an amazing repo.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • change quaternions update as original alphafold2

    change quaternions update as original alphafold2

    In the original alphafold2 IPA module, pure-quaternion (without real part) description is used for quaternion update. This can be broken down to the residual-update-like formulation. But in this code you use (1, a, b, c) style quaternion so I believe the quaternion update should be done as a simple multiply update. As far as I have tested, the loss seems to go down more efficiently with the modification.

    opened by ShintaroMinami 1
  • #126 maybe omit the 'self.point_attn_logits_scale'?

    #126 maybe omit the 'self.point_attn_logits_scale'?

    Hi luci:

    I read the original paper and compare it to your implement, found one place might be some mistake:

    #126. attn_logits_points = -0.5 * (point_dist * point_weights).sum(dim = -1),

    I thought it should be attn_logits_points = -0.5 * (point_dist * point_weights * self.point_attn_logits_scale).sum(dim = -1)

    Thanks for your sharing!

    opened by CiaoHe 1
  • Application of Invariant point attention : preserver part of structure.

    Application of Invariant point attention : preserver part of structure.

    Hi, lucidrian. First of all really thanks for your work!

    I have a question, how can I change(denoise) the structure only in the region I want, how do I do it? (denoise.py)

    opened by hw-protein 0
  • Equivariance test for IPA Transformer

    Equivariance test for IPA Transformer

    @lucidrains I would like to ask about the equivariance of the transformer (not IPA blocks). I wonder if you checked for the equivariance of the output when you allow the transformation of local points to global points using the updated quaternions and translations. I am not sure why this test fails in my case.

    opened by amrhamedp 1
Owner
Phil Wang
Working with Attention
Phil Wang
Official Implementation and Dataset of "PPR10K: A Large-Scale Portrait Photo Retouching Dataset with Human-Region Mask and Group-Level Consistency", CVPR 2021

Portrait Photo Retouching with PPR10K Paper | Supplementary Material PPR10K: A Large-Scale Portrait Photo Retouching Dataset with Human-Region Mask an

184 Dec 11, 2022
This is the official PyTorch implementation of our paper: "Artistic Style Transfer with Internal-external Learning and Contrastive Learning".

Artistic Style Transfer with Internal-external Learning and Contrastive Learning This is the official PyTorch implementation of our paper: "Artistic S

51 Dec 20, 2022
Winners of DrivenData's Overhead Geopose Challenge

Winners of DrivenData's Overhead Geopose Challenge

DrivenData 22 Aug 04, 2022
PyTorch implementation of ICLR 2022 paper PiCO: Contrastive Label Disambiguation for Partial Label Learning

PiCO: Contrastive Label Disambiguation for Partial Label Learning This is a PyTorch implementation of ICLR 2022 Oral paper PiCO; also see our Project

็Ž‹็š“ๆณข 147 Jan 07, 2023
The project of phase's key role in complex and real NN

Phase-in-NN This is the code for our project at Princeton (co-authors: Yuqi Nie, Hui Yuan). The paper title is: "Neural Network is heterogeneous: Phas

YuqiNie-lab 1 Nov 04, 2021
Source Code for ICSE 2022 Paper - ``Can We Achieve Fairness Using Semi-Supervised Learning?''

Fair-SSL Source Code for ICSE 2022 Paper - Can We Achieve Fairness Using Semi-Supervised Learning? Ethical bias in machine learning models has become

1 Dec 18, 2021
One line to host them all. Bootstrap your image search case in minutes.

One line to host them all. Bootstrap your image search case in minutes. Survey NOW gives the world access to customized neural image search in just on

Jina AI 403 Dec 30, 2022
High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

TL;DR Ignite is a high-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently. Click on the image to

4.2k Jan 01, 2023
A Pytorch implementation of "LegoNet: Efficient Convolutional Neural Networks with Lego Filters" (ICML 2019).

LegoNet This code is the implementation of ICML2019 paper LegoNet: Efficient Convolutional Neural Networks with Lego Filters Run python train.py You c

YangZhaohui 140 Sep 26, 2022
LocUNet is a deep learning method to localize a UE based solely on the reported signal strengths from a set of BSs.

LocUNet LocUNet is a deep learning method to localize a UE based solely on the reported signal strengths from a set of BSs. The method utilizes accura

4 Oct 05, 2022
Representing Long-Range Context for Graph Neural Networks with Global Attention

Graph Augmentation Graph augmentation/self-supervision/etc. Algorithms gcn gcn+virtual node gin gin+virtual node PNA GraphTrans Augmentation methods N

UC Berkeley RISE 67 Dec 30, 2022
๐Ÿš€ An end-to-end ML applications using PyTorch, W&B, FastAPI, Docker, Streamlit and Heroku

๐Ÿš€ An end-to-end ML applications using PyTorch, W&B, FastAPI, Docker, Streamlit and Heroku

Made With ML 82 Jun 26, 2022
Minecraft agent to farm resources using reinforcement learning

BarnyardBot CS 175 group project using Malmo download BarnyardBot.py into the python examples directory and run 'python BarnyardBot.py' in the console

0 Jul 26, 2022
Pose estimation for iOS and android using TensorFlow 2.0

๐Ÿ’ƒ Mobile 2D Single Person (Or Your Own Object) Pose Estimation for TensorFlow 2.0 This repository is forked from edvardHua/PoseEstimationForMobile wh

tucan9389 165 Nov 16, 2022
Official repository for "Action-Based Conversations Dataset: A Corpus for Building More In-Depth Task-Oriented Dialogue Systems"

Action-Based Conversations Dataset (ABCD) This respository contains the code and data for ABCD (Chen et al., 2021) Introduction Whereas existing goal-

ASAPP Research 49 Oct 09, 2022
[CVPR 2021] Official PyTorch Implementation for "Iterative Filter Adaptive Network for Single Image Defocus Deblurring"

IFAN: Iterative Filter Adaptive Network for Single Image Defocus Deblurring Checkout for the demo (GUI/Google Colab)! The GUI version might occasional

Junyong Lee 173 Dec 30, 2022
Fine-tuning StyleGAN2 for Cartoon Face Generation

Cartoon-StyleGAN ๐Ÿ™ƒ : Fine-tuning StyleGAN2 for Cartoon Face Generation Abstract Recent studies have shown remarkable success in the unsupervised imag

Jihye Back 520 Jan 04, 2023
NAVER BoostCamp Final Project

CV 14์กฐ final project Super Resolution and Deblur module Inference code & Pretrained weight Repo SwinIR Deblur ์‹คํ–‰ ๋ฐฉ๋ฒ• streamlit run WebServer/Server_SRD

JiSeong Kim 5 Sep 06, 2022
Implementation of Artificial Neural Network Algorithm

Artificial Neural Network This repository contain implementation of Artificial Neural Network Algorithm in several programming languanges and framewor

Resha Dwika Hefni Al-Fahsi 1 Sep 14, 2022
This repository contains a Ruby API for utilizing TensorFlow.

tensorflow.rb Description This repository contains a Ruby API for utilizing TensorFlow. Linux CPU Linux GPU PIP Mac OS CPU Not Configured Not Configur

somatic labs 825 Dec 26, 2022