Implementation of general-purpose self-attention mechanisms, focused on computer vision modules. Ongoing repository.

Overview

Self-attention building blocks for computer vision applications in PyTorch

Implementation of self-attention mechanisms for computer vision in PyTorch, written with einsum and einops.

Install it via pip

It is best to install PyTorch in your environment before installing this package, especially if you don't have a GPU (so that you can choose the CPU-only build yourself).

pip install self-attention-cv

Related articles

More articles are on the way.

Code Examples

Multi-head attention

import torch
from self_attention_cv import MultiHeadSelfAttention

model = MultiHeadSelfAttention(dim=64)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)  # boolean mask, tokens x tokens
mask[5:8, 5:8] = True
y = model(x, mask)
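For readers who want to see what happens inside such a module, here is a minimal NumPy sketch of masked multi-head scaled dot-product attention, written with einsum in the same spirit as the library. It is not the library's code: the helper names are hypothetical, the random projection matrix and the missing output projection are simplifications, and the convention that True mask entries are blocked (set to -inf before the softmax) is an assumption of this sketch.

```python
# Illustrative sketch; helper names are hypothetical, not part of self_attention_cv.
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; exp(-inf) becomes 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, w_qkv, heads, mask=None):
    """x: [batch, tokens, dim]; w_qkv: [dim, 3*dim]; mask: bool [tokens, tokens]."""
    b, t, dim = x.shape
    dim_head = dim // heads
    qkv = x @ w_qkv                                        # [b, t, 3*dim]
    # Split into q, k, v and heads: [3, b, heads, t, dim_head]
    qkv = qkv.reshape(b, t, 3, heads, dim_head).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    scores = np.einsum('bhid,bhjd->bhij', q, k) / np.sqrt(dim_head)
    if mask is not None:
        # Assumed convention: True = this query may not attend to this key
        scores = np.where(mask, -np.inf, scores)
    attn = softmax(scores, axis=-1)                        # [b, heads, t, t]
    out = np.einsum('bhij,bhjd->bhid', attn, v)            # [b, heads, t, dim_head]
    # Merge the heads back into [b, t, dim]
    return out.transpose(0, 2, 1, 3).reshape(b, t, dim), attn


rng = np.random.default_rng(0)
x = rng.random((16, 10, 64))                               # [batch, tokens, dim]
w_qkv = rng.standard_normal((64, 3 * 64)) / np.sqrt(64)
mask = np.zeros((10, 10), dtype=bool)
mask[5:8, 5:8] = True
y, attn = multi_head_self_attention(x, w_qkv, heads=8, mask=mask)
print(y.shape)  # (16, 10, 64)
```

With this mask, queries 5 to 7 get zero attention weight on keys 5 to 7, while every other query/key pair is unaffected.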

Axial attention

import torch
from self_attention_cv import AxialAttentionBlock
model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
x = torch.rand(1, 256, 64, 64)  # [batch, channels, height, width]
y = model(x)
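Axial attention factorizes 2D self-attention into two 1D passes: each pixel first attends to the pixels in its column, then to the pixels in its row, so the cost drops from O((H·W)²) to O(H·W·(H+W)). The NumPy sketch below shows only that factorization. It deliberately omits the learned query/key/value projections, the multiple heads and the positional terms of Axial-DeepLab [2], so it is not the AxialAttentionBlock implementation; the helper names are hypothetical.

```python
# Illustrative sketch; helper names are hypothetical, not part of self_attention_cv.
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend_along_axis(x, axis):
    """Self-attention along one spatial axis of x: [batch, channels, H, W].

    Simplification: q = k = v = x (no learned projections, single head).
    """
    if axis == 2:   # height pass: each column is one sequence
        seq = x.transpose(0, 3, 2, 1)   # [b, W, H, C]
    else:           # width pass: each row is one sequence
        seq = x.transpose(0, 2, 3, 1)   # [b, H, W, C]
    c = seq.shape[-1]
    scores = np.einsum('bnic,bnjc->bnij', seq, seq) / np.sqrt(c)
    out = np.einsum('bnij,bnjc->bnic', softmax(scores), seq)
    # Undo the transpose to get back [b, C, H, W]
    return out.transpose(0, 3, 2, 1) if axis == 2 else out.transpose(0, 3, 1, 2)


def axial_attention(x):
    # Height pass followed by width pass
    return attend_along_axis(attend_along_axis(x, axis=2), axis=3)


x = np.random.default_rng(0).random((1, 8, 16, 16))  # [batch, channels, height, width]
y = axial_attention(x)
print(y.shape)  # (1, 8, 16, 16)
```

For the 64 x 64 map in the example above, full attention scores 4096² ≈ 16.8 M pixel pairs per head, whereas the two axial passes score 4096 × (64 + 64) = 524,288 pairs, a 32× reduction.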

Vanilla Transformer Encoder

import torch
from self_attention_cv import TransformerEncoder
model = TransformerEncoder(dim=64, blocks=6, heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)  # boolean mask, tokens x tokens
mask[5:8, 5:8] = True
y = model(x, mask)
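TransformerEncoder stacks `blocks` layers. A common post-norm layout, following Vaswani et al. [1], is: self-attention with a residual connection and LayerNorm, then a two-layer MLP with another residual connection and LayerNorm. The NumPy sketch below shows that data flow with random weights. The single attention head, ReLU, missing dropout, the MLP hidden size and the post-norm (rather than pre-norm) ordering are all assumptions of this sketch and may not match the library's blocks exactly; the helper names are hypothetical.

```python
# Illustrative sketch; helper names are hypothetical, not part of self_attention_cv.
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the feature dimension (no learned scale/shift in this sketch)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, wq, wk, wv, mask=None):
    q, k, v = x @ wq, x @ wk, x @ wv                        # [b, t, d] each
    scores = np.einsum('bid,bjd->bij', q, k) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)            # True = blocked (assumed)
    return softmax(scores) @ v


def encoder_block(x, params, mask=None):
    wq, wk, wv, w1, w2 = params
    y = layer_norm(x + self_attention(x, wq, wk, wv, mask))  # attention + residual
    return layer_norm(y + np.maximum(y @ w1, 0) @ w2)        # ReLU MLP + residual


def transformer_encoder(x, all_params, mask=None):
    for params in all_params:                                 # one parameter set per block
        x = encoder_block(x, params, mask)
    return x


rng = np.random.default_rng(0)
dim, hidden, blocks = 64, 256, 6   # hidden size is an arbitrary choice for this sketch
shapes = [(dim, dim)] * 3 + [(dim, hidden), (hidden, dim)]
all_params = [[rng.standard_normal(s) / np.sqrt(s[0]) for s in shapes] for _ in range(blocks)]
x = rng.random((16, 10, 64))       # [batch, tokens, dim]
y = transformer_encoder(x, all_params)
print(y.shape)  # (16, 10, 64)
```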

Vision Transformer with/without ResNet50 backbone for image classification

import torch
from self_attention_cv import ViT, ResNet50ViT

model1 = ResNet50ViT(img_dim=128, pretrained_resnet=False,
                     blocks=6, num_classes=10,
                     dim_linear_block=256, dim=256)
# or
model2 = ViT(img_dim=256, in_channels=3, patch_dim=16, num_classes=10, dim=512)
x = torch.rand(2, 3, 256, 256)
y = model2(x)  # [2, 10]
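ViT splits the image into non-overlapping patch_dim x patch_dim patches and flattens each one into a token. For the ViT call above (img_dim=256, patch_dim=16, in_channels=3) that gives (256/16)² = 256 tokens, each of length 16 · 16 · 3 = 768 before the linear projection to dim=512. The NumPy sketch below performs the same split with reshape/transpose, which is what an einops pattern such as 'b c (x p1) (y p2) -> b (x y) (p1 p2 c)' expresses. The helper name is hypothetical, and the sketch leaves out any class token the model may prepend.

```python
# Illustrative sketch; `patchify` is a hypothetical helper, not part of self_attention_cv.
import numpy as np


def patchify(img, patch_dim):
    """img: [batch, channels, H, W] -> tokens: [batch, num_patches, patch_dim**2 * channels]."""
    b, c, h, w = img.shape
    assert h % patch_dim == 0 and w % patch_dim == 0, "image size must be divisible by patch_dim"
    gx, gy = h // patch_dim, w // patch_dim
    # [b, c, gx, p1, gy, p2] -> [b, gx, gy, p1, p2, c]
    x = img.reshape(b, c, gx, patch_dim, gy, patch_dim).transpose(0, 2, 4, 3, 5, 1)
    # Patches are ordered row by row; each token lists pixels as (p1, p2, c)
    return x.reshape(b, gx * gy, patch_dim * patch_dim * c)


img = np.random.default_rng(0).random((2, 3, 256, 256))
tokens = patchify(img, patch_dim=16)
print(tokens.shape)  # (2, 256, 768)
```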

A re-implementation of Unet with the Vision Transformer encoder

import torch
from self_attention_cv.transunet import TransUnet
a = torch.rand(2, 3, 128, 128)
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=8,
                  vit_dim_linear_mhsa_block=512, classes=5)
y = model(a) # [2, 5, 128, 128]
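In TransUnet the Vision Transformer encoder works on a CNN feature map rather than on raw pixels, and the decoder needs that map back in 2D. With a patch size of 1, turning a feature map into tokens reduces to flattening its spatial grid (one token per location), and going back is a plain reshape. The NumPy sketch below shows that round trip. The feature-map shape (256 channels on an 8 x 8 grid, i.e. a 128 x 128 input downsampled 16x) is an assumption for illustration, not a value taken from the library, and the helper names are hypothetical.

```python
# Illustrative sketch; helper names are hypothetical, not part of self_attention_cv.
import numpy as np


def to_tokens(fmap):
    """[batch, channels, h, w] -> [batch, h*w, channels]: one token per spatial location."""
    b, c, h, w = fmap.shape
    return fmap.reshape(b, c, h * w).transpose(0, 2, 1)


def from_tokens(tokens, h, w):
    """Inverse of to_tokens: [batch, h*w, channels] -> [batch, channels, h, w]."""
    b, n, c = tokens.shape
    assert n == h * w, "token count must match the grid size"
    return tokens.transpose(0, 2, 1).reshape(b, c, h, w)


# Hypothetical CNN feature map for a 128 x 128 input downsampled 16x
fmap = np.random.default_rng(0).random((2, 256, 8, 8))
tokens = to_tokens(fmap)              # 64 tokens, one per grid cell
restored = from_tokens(tokens, 8, 8)
print(tokens.shape, np.array_equal(restored, fmap))  # (2, 64, 256) True
```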

Bottleneck Attention block

import torch
from self_attention_cv.bottleneck_transformer import BottleneckBlock
inp = torch.rand(1, 512, 32, 32)
bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4, out_channels=1024, pooling=True)
y = bottleneck_block(inp)
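In a Bottleneck Transformer block [3], multi-head self-attention replaces the 3 x 3 convolution of a ResNet bottleneck and attends over every position of the feature map. Before attention, the [batch, channels, height, width] map has to be rearranged into per-head token sequences. The NumPy sketch below shows that rearrangement (the einops pattern 'b (h d) x y -> b h (x y) d') and its inverse. For illustration it splits the 512 input channels of the example directly into 4 heads; inside the real block, attention may run on a projected channel width, and the helper names are hypothetical.

```python
# Illustrative sketch; helper names are hypothetical, not part of self_attention_cv.
import numpy as np


def split_heads(fmap, heads):
    """[b, heads*d, x, y] -> [b, heads, x*y, d] (einops: 'b (h d) x y -> b h (x y) d')."""
    b, c, h, w = fmap.shape
    assert c % heads == 0, "channels must be divisible by heads"
    d = c // heads
    return fmap.reshape(b, heads, d, h * w).transpose(0, 1, 3, 2)


def merge_heads(tokens, h, w):
    """Inverse of split_heads: [b, heads, h*w, d] -> [b, heads*d, h, w]."""
    b, heads, n, d = tokens.shape
    return tokens.transpose(0, 1, 3, 2).reshape(b, heads * d, h, w)


fmap = np.random.default_rng(0).random((1, 512, 32, 32))
q = split_heads(fmap, heads=4)
print(q.shape)  # (1, 4, 1024, 128)
```

Each of the 4 heads then attends over 32 · 32 = 1024 tokens of size 128.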

Position embeddings are also available

1D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import AbsPosEmb1D, RelPosEmb1D

model = AbsPosEmb1D(tokens=20, dim_head=64)
# [batch, heads, tokens, dim_head]
q = torch.rand(2, 3, 20, 64)
y1 = model(q)

model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
q = torch.rand(2, 3, 20, 64)
y2 = model(q)
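In the usual formulation (for example in the Bottleneck Transformer [3]), such embeddings act on queries of shape [batch, heads, tokens, dim_head] and produce position-dependent logits that are added to the content logits q·k. An absolute embedding learns one vector per position. A relative embedding learns one vector per offset j - i, i.e. 2·tokens - 1 vectors, so it first yields [tokens, 2·tokens - 1] logits that must be turned into a [tokens, tokens] matrix. The NumPy sketch below shows the pad-and-reshape trick commonly used for this step (often called rel_to_abs) and checks it against a direct loop. It illustrates the technique and is not necessarily how RelPosEmb1D is written internally; the names are hypothetical.

```python
# Illustrative sketch; names are hypothetical, not part of self_attention_cv.
import numpy as np


def rel_to_abs(x):
    """x: [batch, heads, L, 2L-1] relative logits -> [batch, heads, L, L] absolute logits.

    Input column r refers to offset r - (L - 1); output[i, j] = input[i, j - i + L - 1].
    """
    b, h, l, _ = x.shape
    x = np.concatenate([x, np.zeros((b, h, l, 1))], axis=3)          # [b, h, L, 2L]
    flat = x.reshape(b, h, l * 2 * l)
    flat = np.concatenate([flat, np.zeros((b, h, l - 1))], axis=2)   # length (L+1)(2L-1)
    return flat.reshape(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1:]


rng = np.random.default_rng(0)
tokens, dim_head = 20, 64
q = rng.random((2, 3, tokens, dim_head))                   # [batch, heads, tokens, dim_head]
rel_emb = rng.standard_normal((2 * tokens - 1, dim_head))  # one vector per offset
rel_logits = np.einsum('bhid,rd->bhir', q, rel_emb)        # [2, 3, 20, 39]
abs_logits = rel_to_abs(rel_logits)                         # [2, 3, 20, 20]

# Direct (slow) computation for comparison: logit[i, j] = q_i . rel_emb[j - i + L - 1]
direct = np.empty_like(abs_logits)
for i in range(tokens):
    for j in range(tokens):
        direct[:, :, i, j] = q[:, :, i] @ rel_emb[j - i + tokens - 1]
print(abs_logits.shape, np.allclose(abs_logits, direct))  # (2, 3, 20, 20) True
```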

2D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import RelPosEmb2D
dim = 32  # spatial size of the (dim x dim) feature map
model = RelPosEmb2D(
    feat_map_size=(dim, dim),
    dim_head=128)

q = torch.rand(2, 4, dim*dim, 128)  # [batch, heads, tokens = dim*dim, dim_head]
y = model(q)
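For a dim x dim feature map flattened into dim·dim tokens (1024 tokens for dim = 32 above), a common factorization, used in the Bottleneck Transformer [3], keeps one set of 2·dim - 1 embeddings for row offsets and one for column offsets and adds their logits: logit[(x1, y1), (x2, y2)] = q·r_h[x2 - x1] + q·r_w[y2 - y1]. The NumPy sketch below builds that [tokens, tokens] matrix with index arrays on a small 4 x 4 map. It illustrates the idea under these assumptions and is not the RelPosEmb2D source; the helper name is hypothetical.

```python
# Illustrative sketch; `rel_pos_logits_2d` is hypothetical, not part of self_attention_cv.
import numpy as np


def rel_pos_logits_2d(q, rel_h, rel_w, size):
    """q: [b, heads, size*size, d]; rel_h, rel_w: [2*size-1, d] -> logits [b, heads, N, N]."""
    n = size * size
    xs, ys = np.divmod(np.arange(n), size)       # row and column of each flattened token
    # Offset (key minus query) for every pair, shifted to be a valid index
    dx = xs[None, :] - xs[:, None] + size - 1    # [N, N]
    dy = ys[None, :] - ys[:, None] + size - 1
    qh = np.einsum('bhnd,rd->bhnr', q, rel_h)    # [b, heads, N, 2*size-1]
    qw = np.einsum('bhnd,rd->bhnr', q, rel_w)
    rows = np.arange(n)[:, None]                 # query index, broadcast against dx / dy
    return qh[:, :, rows, dx] + qw[:, :, rows, dy]


rng = np.random.default_rng(0)
size, d = 4, 8
q = rng.random((2, 4, size * size, d))
rel_h = rng.standard_normal((2 * size - 1, d))
rel_w = rng.standard_normal((2 * size - 1, d))
logits = rel_pos_logits_2d(q, rel_h, rel_w, size)
print(logits.shape)  # (2, 4, 16, 16)
```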

References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
  2. Wang, H., Zhu, Y., Green, B., Adam, H., Yuille, A., & Chen, L. C. (2020, August). Axial-deeplab: Stand-alone axial-attention for panoptic segmentation. In European Conference on Computer Vision (pp. 108-126). Springer, Cham.
  3. Srinivas, A., Lin, T. Y., Parmar, N., Shlens, J., Abbeel, P., & Vaswani, A. (2021). Bottleneck Transformers for Visual Recognition. arXiv preprint arXiv:2101.11605.
  4. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
Comments
  • Thank you very much for the code, but when I run test_TransUnet.py it starts reporting errors. Why is that? Could you please help me solve it? Thank you

    Thank you very much for the code, but when I run test_TransUnet.py it starts reporting errors. Why is that?

    Traceback (most recent call last):
      File "self-attention-cv/tests/test_TransUnet.py", line 14, in <module>
        test_TransUnet()
      File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
        y = model(a)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
        y = self.project_patches_back(y)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
        return F.linear(input, self.weight, self.bias)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
        output = input.matmul(weight.t())
    RuntimeError: mat1 dim 1 must match mat2 dim 0

    Process finished with exit code 1

    Could you please help me solve it? Thank you.

    opened by yezhengjie 7
  • TransUNet - Why is the patch_dim set to 1?

    Hi,

    Can you please explain why patch_dim is set to 1 in the TransUNet class? Thank you in advance!

    https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transunet/trans_unet.py#L54

    opened by dsitnik 7
  • Question: Sliding Window Module for Transformer3dSeg Object

    I was wondering whether you've implemented an example that uses the network in a 3D medical segmentation task. If this network only outputs the center slice of a patch, we would need a wrapper function that iterates through all patches in an image to get the final prediction for the entire volume. From the original paper, I assume they choose 10 patches at random from an image during training, but it's not clear how they pieced everything together during testing.

    Your thoughts on this would be greatly appreciated!

    See: https://github.com/The-AI-Summer/self-attention-cv/blob/33ddf020d2d9fb9c4a4a3b9938383dc9b7405d8c/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L10

    opened by jmarsil 5
  • ResNet + Pyramid Vision Transformer Version 2

    Thank you for your work and the clear explanations. As you know, ViT doesn't perform well on small datasets, so I am combining ResNet34 with Pyramid Vision Transformer Version 2 to improve it. The architectures of ViT and PVT V2 are completely different. Could you please help me implement it?

    opened by khawar-islam 3
  • Request for Including UNETR

    Thanks for the great work! I noticed a nice implementation of this paper (https://arxiv.org/abs/2103.10504) here:

    https://github.com/tamasino52/UNETR/blob/main/unetr.py

    It would be great if this could also be included in your repo, since the repo comes with lots of other great features, so we could explore more.

    Thanks ~

    opened by Siyuan89 3
  • ImageNet Pretrained TimesFormer

    I see you have recently added the TimesFormer model to this repository. In the paper, they initialize their model weights from ImageNet pretrained weights of ViT. Does your implementation offer this too? Thanks!

    opened by RaivoKoot 3
  • Do the encoder modules incorporate positional encoding?

    I am wondering whether, if I use, say, the LinformerEncoder, I have to add the positional encoding myself or whether that's already done. From the source files it doesn't seem to be there, but I'm not sure how to include the positional encoding, since the embeddings seem to need the query, which isn't available when passing data directly to the LinformerEncoder. I may very well be missing something; any help would be great. Perhaps an example using positional encoding would be good.

    opened by jfkback 3
  • use AxialAttention on gpu

    I tried to use AxialAttention on GPU, but I get an error. Can you give me some tips on using AxialAttention on GPU? Thanks! Error: RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

    opened by Iverson-Al 2
  • Axial attention

    What is the meaning of qkv_channels? https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/axial_attention_deeplab/axial_attention.py#L32

    opened by Jayden9912 1
  • Convolution-Free Medical Image Segmentation using Transformers

    Thank you very much for your contribution. As a novice, I have a question. In tranf3dseg, the output of the model is the predicted segmentation of the center patch, so how can I get the segmentation of the whole input image? I am looking forward to any reply.

    opened by WinsaW 1
  • Regression with attention

    Hello!

    thanks for sharing this nice repo :)

    I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.

    My understanding is that I'd need to simply define the network as

    vit = ViT(img_dim=128,
              in_channels=3,
              patch_dim=16,
              num_classes=6,
              dim=512)

    and during training call

    vit(x)

    and compute the loss as MSE instead of CE.

    The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?

    many thanks!

    opened by alemelis 1
  • Segmentation for full image

    Hi,

    Thank you for your effort and time in implementing this. I have a quick question: I want to get the segmentation for the full image, not just for the middle token. Would it be correct to change self.tokens to self.p here:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L66

    and change this:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L94

    to

    y = self.mlp_seg_head(y)

    opened by aqibsaeed 0
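Several of the questions above concern positional information for the encoder modules. One option that does not depend on the query is to add a fixed sinusoidal encoding (Vaswani et al. [1]) to the token embeddings before they enter an encoder, for example to the [batch, tokens, dim] input of TransformerEncoder shown earlier. The NumPy sketch below computes that encoding. It does not assert whether the repository's encoders already add positional information internally, a learned embedding table is an equally common alternative, and the helper name is hypothetical.

```python
# Illustrative sketch; `sinusoidal_encoding` is hypothetical, not part of self_attention_cv.
import numpy as np


def sinusoidal_encoding(tokens, dim):
    """PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(...); returns [tokens, dim]."""
    assert dim % 2 == 0, "dim must be even"
    pos = np.arange(tokens)[:, None]                         # [tokens, 1]
    inv_freq = 1.0 / 10000 ** (np.arange(0, dim, 2) / dim)   # [dim/2]
    angles = pos * inv_freq                                  # [tokens, dim/2]
    pe = np.zeros((tokens, dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


x = np.random.default_rng(0).random((16, 10, 64))   # [batch, tokens, dim]
x_with_pos = x + sinusoidal_encoding(10, 64)         # broadcast over the batch
print(x_with_pos.shape)  # (16, 10, 64)
```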
Releases: 1.2.3
Owner: AI Summer (Learn Deep Learning and Artificial Intelligence)