Implementation of Bottleneck Transformer in Pytorch

Overview

Implementation of Bottleneck Transformer, a SotA visual recognition model with convolution + attention that outperforms EfficientNet and DeiT in terms of performance-compute trade-off, in PyTorch.

Install

$ pip install bottleneck-transformer-pytorch

Usage

import torch
from torch import nn
from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,              # channels in
    fmap_size = 64,         # feature map size
    dim_out = 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = False,    # use relative positional embedding - uses absolute if False
    activation = nn.ReLU()  # activation throughout the network
)

fmap = torch.randn(2, 256, 64, 64) # feature map from previous resnet block(s)

layer(fmap) # (2, 2048, 32, 32)

BotNet

With some simple model surgery off a resnet, you can have the 'BotNet' (what a weird name) for training.

import torch
from torch import nn
from torchvision.models import resnet50

from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,
    fmap_size = 56,        # set specifically for imagenet's 224 x 224
    dim_out = 2048,
    proj_factor = 4,
    downsample = True,
    heads = 4,
    dim_head = 128,
    rel_pos_emb = True,
    activation = nn.ReLU()
)

resnet = resnet50()

# model surgery

backbone = list(resnet.children())
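
# backbone[:5] keeps conv1, bn1, relu, maxpool and layer1 of resnet50,
# which yields a (B, 256, 56, 56) feature map for 224 x 224 inputs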

model = nn.Sequential(
    *backbone[:5],
    layer,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(1),
    nn.Linear(2048, 1000)
)

# use the 'BotNet'

img = torch.randn(2, 3, 224, 224)
preds = model(img) # (2, 1000)

Citations

@misc{srinivas2021bottleneck,
    title   = {Bottleneck Transformers for Visual Recognition}, 
    author  = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
    year    = {2021},
    eprint  = {2101.11605},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • How should I modify the code if the input feature map has unequal width and height?

    Assume that the width and height of the feature map are 10 and 8, respectively. Could you please help me check whether my modification of the class RelPosEmb below is correct?

    class RelPosEmb(nn.Module):
        def __init__(self, fmap_size, dim_head):
            super().__init__()
            scale = dim_head ** -0.5
            self.fmap_size = fmap_size
            self.scale = scale
            # self.rel_height = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_height = nn.Parameter(torch.randn(8 * 2 - 1, dim_head) * scale)
            # self.rel_width = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_width = nn.Parameter(torch.randn(10 * 2 - 1, dim_head) * scale)

        def forward(self, q):
            q = rearrange(q, 'b h (x y) d -> b h x y d', x = 8)
            rel_logits_w = relative_logits_1d(q, self.rel_width)
            rel_logits_w = rearrange(rel_logits_w, 'b h x i y j -> b h (x y) (i j)')

            q = rearrange(q, 'b h x y d -> b h y x d')
            rel_logits_h = relative_logits_1d(q, self.rel_height)
            rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
            return rel_logits_w + rel_logits_h
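
    A more general version of the same idea, as a hedged sketch (RectRelPosEmb is a hypothetical name; relative_logits_1d and rearrange are assumed from the library code): parameterize the two axes separately instead of hardcoding 8 and 10.

    class RectRelPosEmb(nn.Module):
        def __init__(self, fmap_height, fmap_width, dim_head):
            super().__init__()
            scale = dim_head ** -0.5
            self.fmap_height = fmap_height
            # one table per axis, each sized to its own extent
            self.rel_height = nn.Parameter(torch.randn(fmap_height * 2 - 1, dim_head) * scale)
            self.rel_width = nn.Parameter(torch.randn(fmap_width * 2 - 1, dim_head) * scale)

        def forward(self, q):
            q = rearrange(q, 'b h (x y) d -> b h x y d', x = self.fmap_height)
            rel_logits_w = relative_logits_1d(q, self.rel_width)
            rel_logits_w = rearrange(rel_logits_w, 'b h x i y j -> b h (x y) (i j)')

            q = rearrange(q, 'b h x y d -> b h y x d')
            rel_logits_h = relative_logits_1d(q, self.rel_height)
            rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
            return rel_logits_w + rel_logits_h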
    
    opened by ShuweiShao 4
  • Feature map size

    Hi, in my case the input image sizes are all different, so the feature map size keeps changing. How should the fmap_size parameter of BottleStack be set in this case? Is it possible to train with an unfixed feature map size?
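
    A possible workaround, as a minimal sketch (the wrapper below is hypothetical glue code, not part of the library): since BottleStack's positional parameters are sized to a fixed fmap_size, resize incoming feature maps to that size before the block.

    import torch.nn.functional as F
    from torch import nn

    class FixedSizeWrapper(nn.Module):
        def __init__(self, bottle_stack, fmap_size):
            super().__init__()
            self.bottle_stack = bottle_stack
            self.fmap_size = fmap_size

        def forward(self, x):
            # resize to the spatial size the positional embeddings expect
            x = F.interpolate(x, size = (self.fmap_size, self.fmap_size), mode = 'bilinear', align_corners = False)
            return self.bottle_stack(x)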

    opened by benlee73 3
  • A little bug.

    https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/b789de6db39f33854862fbc9bcee27c697cf003c/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L16

    It is necessary to specify the device here; otherwise the new tensor is created on the CPU by default.

    flat_pad = torch.zeros((b, h, l - 1), device = device, dtype = dtype) 
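
    For context, a minimal sketch of the failure mode this fixes (assumes a CUDA device is available):

    import torch

    if torch.cuda.is_available():
        q = torch.randn(2, 4, 8, device = 'cuda')
        pad = torch.zeros(2, 4, 7)        # created on the CPU by default
        torch.cat((q, pad), dim = -1)     # RuntimeError: tensors on different devices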
    
    opened by lartpang 1
  • Fix in-place operations

    Latest versions of PyTorch throw runtime errors for in-place operations like *= and += on tensors that require gradients. This pull request fixes the issue by replacing them with their out-of-place equivalents.
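
    A minimal illustration of the class of error being fixed (illustrative only, not taken from the pull request):

    import torch

    x = torch.randn(4, requires_grad = True)

    y = x.sigmoid()       # sigmoid's backward needs its own output
    y *= 2                # the in-place op overwrites that saved output
    # y.sum().backward()  # RuntimeError: ... modified by an inplace operation

    y = x.sigmoid()
    y = y * 2             # out-of-place equivalent keeps the graph valid
    y.sum().backward()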

    opened by AminRezaei0x443 0
  • Could you explain the implementation of relative position embedding?

    Reference: https://github.com/tensorflow/tensor2tensor/blob/5f9dd2db6d7797162e53adf152310ed13e9fc711/tensor2tensor/layers/common_attention.py

    def _generate_relative_positions_matrix(length_q, length_k,
                                            max_relative_position,
                                            cache=False):
      """Generates matrix of relative positions between inputs."""
      if not cache:
        if length_q == length_k:
          range_vec_q = range_vec_k = tf.range(length_q)
        else:
          range_vec_k = tf.range(length_k)
          range_vec_q = range_vec_k[-length_q:]
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
      else:
        distance_mat = tf.expand_dims(tf.range(-length_k+1, 1, 1), 0)
      distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                              max_relative_position)
      # Shift values to be >= 0. Each integer still uniquely identifies a relative
      # position difference.
      final_mat = distance_mat_clipped + max_relative_position
      return final_mat
    
    
    def _generate_relative_positions_embeddings(length_q, length_k, depth,
                                                max_relative_position, name,
                                                cache=False):
      """Generates tensor of size [1 if cache else length_q, length_k, depth]."""
      with tf.variable_scope(name):
        relative_positions_matrix = _generate_relative_positions_matrix(
            length_q, length_k, max_relative_position, cache=cache)
        vocab_size = max_relative_position * 2 + 1
        # Generates embedding for each relative position of dimension depth.
        embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
        embeddings = tf.gather(embeddings_table, relative_positions_matrix)
        return embeddings
    
    
    def _relative_attention_inner(x, y, z, transpose):
      """Relative position-aware dot-product attention inner calculation.
      This batches matrix multiply calculations to avoid unnecessary broadcasting.
      Args:
        x: Tensor with shape [batch_size, heads, length or 1, length or depth].
        y: Tensor with shape [batch_size, heads, length or 1, depth].
        z: Tensor with shape [length or 1, length, depth].
        transpose: Whether to transpose inner matrices of y and z. Should be true if
            last dimension of x is depth, not length.
      Returns:
        A Tensor with shape [batch_size, heads, length, length or depth].
      """
      batch_size = tf.shape(x)[0]
      heads = x.get_shape().as_list()[1]
      length = tf.shape(x)[2]
    
      # xy_matmul is [batch_size, heads, length or 1, length or depth]
      xy_matmul = tf.matmul(x, y, transpose_b=transpose)
      # x_t is [length or 1, batch_size, heads, length or depth]
      x_t = tf.transpose(x, [2, 0, 1, 3])
      # x_t_r is [length or 1, batch_size * heads, length or depth]
      x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
      # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
      x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
      # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
      x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
      # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
      x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
      return xy_matmul + x_tz_matmul_r_t
    
    
    def dot_product_attention_relative(q,
                                       k,
                                       v,
                                       bias,
                                       max_relative_position,
                                       dropout_rate=0.0,
                                       image_shapes=None,
                                       save_weights_to=None,
                                       name=None,
                                       make_image_summary=True,
                                       cache=False,
                                       allow_memory=False,
                                       hard_attention_k=0,
                                       gumbel_noise_weight=0.0):
      """Calculate relative position-aware dot-product self-attention.
      The attention calculation is augmented with learned representations for the
      relative position between each element in q and each element in k and v.
      Args:
        q: a Tensor with shape [batch, heads, length, depth].
        k: a Tensor with shape [batch, heads, length, depth].
        v: a Tensor with shape [batch, heads, length, depth].
        bias: bias Tensor.
        max_relative_position: an integer specifying the maximum distance between
            inputs that unique position embeddings should be learned for.
        dropout_rate: a floating point number.
        image_shapes: optional tuple of integer scalars.
        save_weights_to: an optional dictionary to capture attention weights
          for visualization; the weights tensor will be appended there under
          a string key created from the variable scope (including name).
        name: an optional string.
        make_image_summary: Whether to make an attention image summary.
        cache: whether use cache mode
        allow_memory: whether to assume that recurrent memory is in use. If True,
          the length dimension of k/v/bias may be longer than the queries, and it is
          assumed that the extra memory entries precede the non-memory entries.
        hard_attention_k: integer, if > 0 triggers hard attention (picking top-k)
        gumbel_noise_weight: if > 0, apply Gumbel noise with weight
          `gumbel_noise_weight` before picking top-k. This is a no op if
          hard_attention_k <= 0.
      Returns:
        A Tensor.
      Raises:
        ValueError: if max_relative_position is not > 0.
      """
      if not max_relative_position:
        raise ValueError("Max relative position (%s) should be > 0 when using "
                         "relative self attention." % (max_relative_position))
      with tf.variable_scope(
          name, default_name="dot_product_attention_relative",
          values=[q, k, v]) as scope:
    
        # This calculation only works for self attention.
        # q, k and v must therefore have the same shape, unless memory is enabled.
        if not cache and not allow_memory:
          q.get_shape().assert_is_compatible_with(k.get_shape())
          q.get_shape().assert_is_compatible_with(v.get_shape())
    
        # Use separate embeddings suitable for keys and values.
        depth = k.get_shape().as_list()[3]
        length_k = common_layers.shape_list(k)[2]
        length_q = common_layers.shape_list(q)[2] if allow_memory else length_k
        relations_keys = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_keys", cache=cache)
        relations_values = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_values", cache=cache)
    
        # Compute self attention considering the relative position embeddings.
        logits = _relative_attention_inner(q, k, relations_keys, True)
        if bias is not None:
          logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if hard_attention_k > 0:
          weights = harden_attention_weights(weights, hard_attention_k,
                                             gumbel_noise_weight)
        if save_weights_to is not None:
          save_weights_to[scope.name] = weights
          save_weights_to[scope.name + "/logits"] = logits
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
            common_layers.should_generate_summaries() and
            make_image_summary):
          attention_image_summary(weights, image_shapes)
        return _relative_attention_inner(weights, v, relations_values, False)
    

    This corresponds to the formula clip(x, k) = max(-k, min(k, x)).

    But in your code there is a randn with gradients enabled, and I don't understand why. Could you explain?
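
    For comparison, a hedged PyTorch sketch of the same idea: nn.Parameter(torch.randn(...)) plays the role of tf.get_variable("embeddings", [vocab_size, depth]) above. The randn only initializes the table; gradients update it during training.

    import torch
    from torch import nn

    length, depth = 8, 16

    # counterpart of embeddings_table: one learned row per relative
    # offset in [-(length - 1), length - 1]
    rel_table = nn.Parameter(torch.randn(2 * length - 1, depth) * depth ** -0.5)

    # counterpart of final_mat: relative distances shifted to be >= 0
    pos = torch.arange(length)
    rel_idx = (pos[None, :] - pos[:, None]) + (length - 1)   # (length, length)

    rel_emb = rel_table[rel_idx]   # (length, length, depth), like tf.gather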

    opened by AncientRemember 0
  • Is it possible to modify these codes to support 3D images as well?

    Thank you for your great work!

    I was wondering if it is possible to modify these codes to support 3D images as well (i.e. adding a z-axis).

    I can't imagine how to change the dimensions of the vectors in the "content-position" part. E.g. Hx1xd and 1xWxd -> Hx1x1xd, 1x1xZxd and 1xWx1xd?
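
    A hedged sketch of exactly that broadcasting, just to check the shapes work out:

    import torch

    H, W, Z, d = 4, 5, 6, 8
    emb_h = torch.randn(H, 1, 1, d)
    emb_w = torch.randn(1, W, 1, d)
    emb_z = torch.randn(1, 1, Z, d)

    # broadcasting sums the three per-axis embeddings into one positional tensor
    pos = emb_h + emb_w + emb_z
    print(pos.shape)   # torch.Size([4, 5, 6, 8])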

    Thank you for your answer!

    opened by kyuchoi 0
  • Hello, I get the following error during training. How can I solve it?

    einops.EinopsError: Error while processing rearrange-reduction pattern "b h (x y) d -> b h x y d". Input tensor shape: torch.Size([2, 4, 900, 128]). Additional info: {'x': 26, 'y': 26}. Shape mismatch, 900 != 676
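
    A hedged reading of the numbers: 900 tokens corresponds to a 30 x 30 feature map, while fmap_size here implies 26 x 26 = 676, so fmap_size does not match the resolution actually reaching the layer.

    import math

    tokens = 900
    print(math.isqrt(tokens))   # 30 -> side of the actual feature map
    print(26 * 26)              # 676 -> what fmap_size = 26 expects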

    opened by glt999 1
  • The size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3

    Hello, I want to ask a question. The input feature map is 228 x 304, but I get this error: the size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3.

    opened by shezhi 1
  • The 2D relative position embedding is not inductive; maybe the FLOATER embedding would be better

    opened by AncientRemember 2