Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute

Overview

Lambda Networks - PyTorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The new method uses a λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
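
As a rough illustration of "transforming contexts into linear functions", here is a minimal PyTorch sketch of the content lambda only. It assumes dim_u = 1 and leaves out the position lambdas, the query/key/value projections and batch normalization, so it is a simplified sketch and not this package's implementation. The context is summarized into one dim_k x dim_v matrix per example, and that same matrix is applied to every query.

```python
import torch

def content_lambda(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Content-only lambda (dim_u = 1, no position lambdas); illustrative sketch.

    q: B x heads x dim_k x N   queries, one per output position
    k: B x dim_k x M           keys, one per context position
    v: B x dim_v x M           values, one per context position
    returns: B x N x heads x dim_v
    """
    k = k.softmax(dim=-1)                          # normalize keys over the M context positions
    lam = torch.einsum("bkm,bvm->bkv", k, v)       # the lambda: one dim_k x dim_v linear map per example
    return torch.einsum("bhkn,bkv->bnhv", q, lam)  # apply the same map to every query

torch.manual_seed(0)
q = torch.randn(2, 4, 16, 64)
k = torch.randn(2, 16, 64)
v = torch.randn(2, 8, 64)
y = content_lambda(q, k, v)   # shape (2, 64, 4, 8)
```

Because the lambda is built once per example and then reused, the cost of this part grows linearly with the number of positions rather than quadratically.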

Yannic Kilcher's paper review

Install

$ pip install lambda-networks

Usage

Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

For fun, you can also import it as follows:

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy/paste the code in ./lambda_networks/tfkeras.py, or make sure TensorFlow and Keras are installed before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)

Citations

@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}
Comments
  • Contiguity problem:

    "RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input."

    It seems that LambdaLayer breaks contiguity when I try it:

    layer(x).is_contiguous()
    >> False
    

    I have to call .contiguous() wherever I train with it. Is this normal?

    opened by Whiax 5
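
One way to avoid sprinkling .contiguous() through the training code is to wrap the layer once. The sketch below uses a hypothetical Contiguous wrapper; NonContiguousStandIn is only a stand-in that returns a non-contiguous tensor of the right shape (it uses channels_last strides), not the actual LambdaLayer, whose exact output layout depends on its implementation.

```python
import torch
import torch.nn as nn

class Contiguous(nn.Module):
    """Hypothetical wrapper: hands back the wrapped module's output in contiguous memory."""
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # .contiguous() returns the tensor itself if it is already contiguous,
        # otherwise a contiguous copy holding the same values
        return self.inner(x).contiguous()

class NonContiguousStandIn(nn.Module):
    """Stand-in for a layer whose output has non-default strides."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.contiguous(memory_format=torch.channels_last)

x = torch.randn(1, 32, 64, 64)
print(NonContiguousStandIn()(x).is_contiguous())              # False
print(Contiguous(NonContiguousStandIn())(x).is_contiguous())  # True
```

With the real layer this would read Contiguous(LambdaLayer(...)); the extra cost is one copy of the output per forward pass whenever it is not already contiguous.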
  • Warning: Mixed memory format inputs detected while calling the operator.

    I have added lambda layers to every block of a ResNet, and the following warning appears. Will it affect the result?

    Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

    opened by Pluto1314 4
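
The warning concerns storage layout, not values: a channels_last tensor and a default (NCHW-contiguous) tensor holding the same numbers compare equal. Assuming the warning comes from combining the two layouts, for example in the residual addition of a block where only one branch is channels_last, a minimal sketch that puts both operands into one layout before combining them looks like this (match_layout is a hypothetical helper).

```python
import torch

def match_layout(ref: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Return `other` stored in the same memory format as `ref` (channels_last or default)."""
    if ref.dim() == 4 and not ref.is_contiguous() and ref.is_contiguous(memory_format=torch.channels_last):
        return other.contiguous(memory_format=torch.channels_last)
    return other.contiguous()

identity = torch.randn(2, 32, 16, 16).contiguous(memory_format=torch.channels_last)  # e.g. shortcut branch
branch = torch.randn(2, 32, 16, 16)                                                  # e.g. lambda branch
out = identity + match_layout(identity, branch)   # both operands now share one layout
```

Alternatively, converting the model and inputs with .to(memory_format=torch.channels_last) may reduce the mismatch, though layers that build their output through reshapes can still return default-layout tensors.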
  • Implementation of Lambda convolution

    Thanks for the great work on the implementations!

    I would like to ask whether there is a difference between using 'Conv2d', as suggested in Eq. 3 of the paper, and your implementation using 'Conv3d'. These two convs treat the (h x w) dimension as a 1-D sequence and a 2-D image, respectively. I believe they are quite different in concept.

    Please point out if I have misunderstood.

    Thanks a lot.

    opened by romulus0914 3
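
A quick numerical check helps frame this question. The sketch below assumes the layout used by the Conv3d-based code, with values of shape B x dim_u x dim_v x H x W and a kernel of size (1, r, r), and shows that this equals a Conv2d over (H, W) applied to each value channel separately with shared weights. It does not by itself settle how Eq. 3 in the paper lays out its axes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, dim_u, dim_k, dim_v, h, w, r = 2, 4, 16, 8, 10, 12, 5

# Conv3d over (dim_v, H, W) with a depth-1 kernel, i.e. no mixing across value channels
conv3d = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2), bias=False)
# Conv2d over (H, W) that reuses the same weights
conv2d = nn.Conv2d(dim_u, dim_k, r, padding=r // 2, bias=False)
with torch.no_grad():
    conv2d.weight.copy_(conv3d.weight.squeeze(2))   # (dim_k, dim_u, 1, r, r) -> (dim_k, dim_u, r, r)

values = torch.randn(b, dim_u, dim_v, h, w)          # B x dim_u x dim_v x H x W
out3d = conv3d(values)                               # B x dim_k x dim_v x H x W

# fold dim_v into the batch so Conv2d sees each value channel as a separate image
folded = values.permute(0, 2, 1, 3, 4).reshape(b * dim_v, dim_u, h, w)
out2d = conv2d(folded).reshape(b, dim_v, dim_k, h, w).permute(0, 2, 1, 3, 4)

print(torch.allclose(out3d, out2d, atol=1e-4))
```

In other words, the Conv3d here is a convenience for running one 2-D spatial convolution per value channel; both forms treat positions as a 2-D grid.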
  • How lambda layer handle the downsample in LambdaResNet?

    Hi, thanks for your clear code. I am trying to implement LambdaResNet. Does the lambda layer replace all conv2d layers? If so, how does the lambda layer handle the downsampling done by a conv2d with stride=2? Or should we keep the conv2d where stride=2 and replace only the conv2d layers with stride=1?

    opened by qiaoran-dawnlight 2
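
One possible answer to the stride question, stated as an assumption rather than something this repository prescribes, is to keep the lambda layer at stride 1 and downsample right after it with average pooling. StridedWrapper below is a hypothetical helper, and the 1x1 Conv2d is only a stand-in for a lambda layer so the sketch runs on its own.

```python
import torch
import torch.nn as nn

class StridedWrapper(nn.Module):
    """Hypothetical helper: run a stride-1 block, then downsample with 3x3 average pooling."""
    def __init__(self, block: nn.Module, stride: int = 1):
        super().__init__()
        self.block = block
        # kernel 3 with padding 1 halves even spatial sizes exactly when stride=2
        self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if stride > 1 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(self.block(x))

stand_in = nn.Conv2d(32, 32, kernel_size=1)    # placeholder for a stride-1 lambda layer
down = StridedWrapper(stand_in, stride=2)
y = down(torch.randn(1, 32, 64, 64))           # shape (1, 32, 32, 32)
```

Keeping a strided conv2d in the downsampling blocks, as the question suggests, is the other straightforward option; which one works better is an empirical question.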
  • question about hybrid lambdaResnet

    Hi,

    In the paper, there is this paragraph:

    When working with hybrid LambdaNetworks, we use a single lambda layer in c4 for LambdaResNet50, 3 lambda layers for LambdaResNet101, 6 lambda layers for LambdaResNet-152/200/270/350 and 8 lambda layers for LambdaResNet-420.

    I have several questions about constructing the hybrid lambdaResnet:

    1. Do we only need to replace the 3x3 conv with a lambda layer in the C4 stage, rather than in both C4 and C5 (as in the ablation study)?
    2. When there is more than one lambda layer, as in LambdaResNet101, are we replacing the 3x3 convs with 3 lambda layers? And in the ResNet50 case, do we replace one 3x3 conv with 1 lambda layer?

    opened by CoinCheung 2
  • Question: Is there an easy way to visualise lambdas?

    I want to train a classifier and tell which regions it pays the most... well... attention to :) And I want to do this simultaneously with inference, without using Grad-CAM etc. Can I do this?

    opened by lebionick 1
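
A lambda layer summarizes the whole context into one linear map, so there is no per-query attention map to read out. A rough, inference-time heuristic is to look at the softmax-normalized keys, which weight each context position. The sketch below assumes the layer exposes its key projection as a submodule (the Holocron code quoted in these comments calls it to_k) and captures it with a forward hook; key_saliency is a hypothetical helper and the Conv2d is a stand-in.

```python
import torch
import torch.nn as nn

def key_saliency(key_proj: torch.Tensor, dim_u: int = 1) -> torch.Tensor:
    """Heuristic map from a key projection of shape B x (dim_k * dim_u) x H x W.

    Keys are softmax-normalized over all H * W context positions, so averaging
    those distributions over dim_k and dim_u leaves one weight per pixel.
    """
    b, c, h, w = key_proj.shape
    k = key_proj.reshape(b, c // dim_u, dim_u, h * w).softmax(dim=-1)
    return k.mean(dim=(1, 2)).reshape(b, h, w)

# capture the key projection during the ordinary forward pass with a hook
captured = {}
to_k = nn.Conv2d(32, 16, 1, bias=False)   # stand-in for a lambda layer's key projection
handle = to_k.register_forward_hook(lambda module, inputs, output: captured.update(keys=output.detach()))
_ = to_k(torch.randn(2, 32, 8, 8))
handle.remove()

saliency = key_saliency(captured["keys"])  # B x H x W, each map sums to 1
```

Since the same key weights serve every query, this shows which context positions feed the lambda overall, not which regions a particular output position looked at.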
  • Fix relative positional attention (position lambda)

    Hi lucidrains, thanks for your nice work.

    I found an error in the position lambda (relative positional attention) implementation. The relative positional attention, λp, should be translation equivariant, as written in Sec. 3.2 of the paper. This means the positional embedding has the constraint E[n, m] = E[t(n), t(m)], which is missing in the current implementation. This PR fixes it by adding the translation equivariance constraint. I checked that this PR improves the result in my experiment.

    NOTE that this PR changes the function parameter n from the total area (n=w*h) to the length of each side (n=w=h).

    opened by khanrc 1
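
The constraint E[n, m] = E[t(n), t(m)] can be enforced by storing one embedding per relative offset and gathering it for every (query, context) pair. The sketch below assumes a square n x n grid, as in the PR; relative_index and flat are hypothetical helpers, and the table would be a learnable nn.Parameter in a real layer.

```python
import torch

def relative_index(n: int) -> torch.Tensor:
    """For an n x n grid, the shifted 2D offset between every (query, context) pixel pair."""
    ii, jj = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
    pos = torch.stack((ii.flatten(), jj.flatten()), dim=-1)  # (n*n, 2) row/col of each pixel
    return pos[None, :, :] - pos[:, None, :] + (n - 1)        # offsets shifted into [0, 2n-2]

def flat(i: int, j: int, n: int) -> int:
    return i * n + j                                          # row-major pixel index

n, dim_k, dim_u = 4, 16, 1
table = torch.randn(2 * n - 1, 2 * n - 1, dim_k, dim_u)      # one embedding per relative offset
idx = relative_index(n)
E = table[idx[..., 0], idx[..., 1]]                          # (n*n, n*n, dim_k, dim_u)

# translating query and context by the same (1, 1) leaves the embedding unchanged
print(torch.equal(E[flat(0, 0, n), flat(1, 2, n)], E[flat(1, 1, n), flat(2, 3, n)]))  # True
```

Besides enforcing equivariance, the offset table has (2n-1)^2 entries per (dim_k, dim_u) pair, instead of n^4 for a free (n*n) x (n*n) embedding.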
  • Use of Keras Lambda

    Hey! Thanks for the awesome implementations :D

    I was wondering why tf.keras.layers.Lambda is used here. It seems unnecessary: regular calls to TF operations work and are more readable.

    https://github.com/lucidrains/lambda-networks/blob/06a48f2a5b41f3cd278aee67838c32051a0a9bed/lambda_networks/tfkeras.py#L73

    You can also call the functional version of the softmax instead.

    opened by cgarciae 1
  • Lambda for a sequence of images

    Thanks for the quick implementation!

    I have a problem where I have a sequence of images (a video) rather than a single image. So instead of the dimensions batch, channels, height, width, I also have a length dimension after batch that holds the sequence length.

    Given a known max_length (for the positional embedding), should conv4d be used in forward instead of conv3d, to allow interaction between frames?

    The paper does mention that this could serve as a general framework for sequences of images, so I wonder whether you explored that in the implementation (where a single image is simply the case length=1).

    opened by AmitMY 1
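
PyTorch has no built-in Conv4d. Assuming the local-context formulation, one workaround is to fold dim_v into the batch dimension and use a Conv3d whose kernel spans (time, height, width). VideoPositionLambdas below is a hypothetical sketch of the position lambdas only; the content lambda needs no change beyond flattening T x H x W into one position axis.

```python
import torch
import torch.nn as nn

class VideoPositionLambdas(nn.Module):
    """Hypothetical sketch: local position lambdas over (time, height, width).

    dim_v is folded into the batch, and a Conv3d with an (rt, r, r) kernel
    mixes neighbouring frames as well as neighbouring pixels.
    """
    def __init__(self, dim_u: int, dim_k: int, r: int, rt: int):
        super().__init__()
        if r % 2 == 0 or rt % 2 == 0:
            raise ValueError("odd kernel sizes keep the output aligned with the input")
        self.conv = nn.Conv3d(dim_u, dim_k, (rt, r, r), padding=(rt // 2, r // 2, r // 2), bias=False)

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        b, u, dv, t, h, w = v.shape                                         # B x dim_u x dim_v x T x H x W
        folded = v.permute(0, 2, 1, 3, 4, 5).reshape(b * dv, u, t, h, w)
        lam = self.conv(folded)                                             # (B * dim_v) x dim_k x T x H x W
        return lam.reshape(b, dv, -1, t, h, w).permute(0, 2, 1, 3, 4, 5)   # B x dim_k x dim_v x T x H x W

v = torch.randn(2, 1, 8, 4, 16, 16)                 # dim_u = 1, dim_v = 8, 4 frames of 16 x 16
lam_p = VideoPositionLambdas(dim_u=1, dim_k=16, r=5, rt=3)(v)
q = torch.randn(2, 4, 16, 4 * 16 * 16)              # B x heads x dim_k x (T * H * W)
y_p = torch.einsum("bhkn,bkvn->bnhv", q, lam_p.flatten(3))  # B x (T * H * W) x heads x dim_v
```

With rt = 1 and T = 1 this reduces to the single-image Conv3d case.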
  • why flops so high

    I used ResNet-50 and changed the C4 layers into LambdaBottleNeck blocks, but the FLOPs are very high, about 20G, with an input size of 224*244. Is that right, or is something wrong with my implementation?

    opened by zisuina 0
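
To see where the compute goes, it helps to count multiply-accumulates (MACs) for the einsums in a lambda layer. The sketch below follows the einsum shapes in the Holocron code quoted in these comments and excludes the 1x1 projections and batch norm; lambda_layer_macs is a hypothetical helper, not part of any package. With global context the position lambdas scale with (H*W)^2, while the other terms and the local (r x r) variant scale linearly with H*W.

```python
def lambda_layer_macs(b, h, w, heads, dim_k, dim_v, dim_u=1, r=None):
    """Hypothetical helper: multiply-accumulates of a lambda layer's einsums.

    n = m = h * w positions; r=None means global context, otherwise an r x r local context.
    The 1x1 query/key/value projections and batch norm are not counted.
    """
    n = m = h * w
    macs = {
        "content_lambda": b * dim_u * dim_k * m * dim_v,   # 'b u k m, b u v m -> b k v'
        "apply_content": b * heads * dim_k * n * dim_v,    # 'b h k n, b k v -> b n h v'
        "apply_position": b * n * heads * dim_k * dim_v,   # 'b h k n, b n k v -> b n h v'
    }
    if r is None:
        # 'n m k u, b u v m -> b n k v': every position against every position
        macs["position_lambdas"] = b * n * m * dim_k * dim_u * dim_v
    else:
        # conv3d with a (1, r, r) kernel: dim_u * r * r MACs per output element
        macs["position_lambdas"] = b * dim_k * dim_v * n * dim_u * r * r
    return macs

# e.g. a 14 x 14 feature map, global vs. local context
global_total = sum(lambda_layer_macs(1, 14, 14, heads=4, dim_k=16, dim_v=64).values())
local_total = sum(lambda_layer_macs(1, 14, 14, heads=4, dim_k=16, dim_v=64, r=23).values())
```

Tools differ in whether they report MACs or FLOPs (often counted as 2 x MACs), and some profilers may not count einsum calls at all, so reported totals are not always directly comparable.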
  • How to load_model correctly

    Hi everyone, I am struggling to load this model from a saved .h5 file.

    What is the correct way to load this network? If I use custom_objects, I get __init__() got an unexpected keyword argument 'name'.

    opened by JNaranjo-Alcazar 0
  • Please add clarity to code

    So, Phil: I love your work, and I wish you could go a few extra steps to help out users. I found this class by François-Guillaume (@frgfm), which adds clear math comments. I want to merge it, but there is a bit of code drift and I don't want to introduce any bugs. I beseech you to go the extra step to help users bridge from papers to code.

    https://github.com/frgfm/Holocron/blob/bcc3ea19a477e4b28dc5973cdbe92a9b05c690bb/holocron/nn/modules/lambda_layer.py

    For example, please annotate return types: def forward(self, x: torch.Tensor) -> torch.Tensor:

    Please add clarifying comments, e.g. # Project input and context to get queries, keys & values

    Throw in some maths as comments; this is great because it bridges the paper and the code.

    B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)

    import torch
    from torch import nn, einsum
    import torch.nn.functional as F
    from typing import Optional
    
    __all__ = ['LambdaLayer']
    
    
    class LambdaLayer(nn.Module):
        """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
        <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
        <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.
        Args:
            in_channels (int): input channels
            out_channels (int, optional): output channels
            dim_k (int): key dimension
            n (int, optional): number of input pixels
            r (int, optional): receptive field for relative positional encoding
            num_heads (int, optional): number of attention heads
            dim_u (int, optional): intra-depth dimension
        """
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            dim_k: int,
            n: Optional[int] = None,
            r: Optional[int] = None,
            num_heads: int = 4,
            dim_u: int = 1
        ) -> None:
            super().__init__()
            self.u = dim_u
            self.num_heads = num_heads
    
            if out_channels % num_heads != 0:
                raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
            dim_v = out_channels // num_heads
    
            # Project input and context to get queries, keys & values
            self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
            self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
            self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)
    
            self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
    
            self.local_contexts = r is not None
            if r is not None:
                if r % 2 != 1:
                    raise AssertionError('Receptive kernel size should be odd')
                self.padding = r // 2
                self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
            else:
                if n is None:
                    raise AssertionError('You must specify the total sequence length (h x w)')
                self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, h, w = x.shape
    
            # Project inputs & context to retrieve queries, keys and values
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)
    
            # Normalize queries & values
            q = self.norm_q(q)
            v = self.norm_v(v)
    
            # B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)
            q = q.reshape(b, self.num_heads, -1, h * w)
            # B x (dim_k * dim_u) x H x W -> B x dim_u x dim_k x (H * W)
            k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
            # B x (dim_v * dim_u) x H x W -> B x dim_u x dim_v x (H * W)
            v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
    
            # Normalized keys
            k = k.softmax(dim=-1)
    
            # Content function
            λc = einsum('b u k m, b u v m -> b k v', k, v)
            Yc = einsum('b h k n, b k v -> b n h v', q, λc)
    
            # Position function
            if self.local_contexts:
                # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
                v = v.reshape(b, self.u, v.shape[2], h, w)
                λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
                Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
            else:
                λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
                Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
    
            Y = Yc + Yp
            # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
            out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
            return out
    
    opened by johndpope 1
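
For readers bridging paper and code, the forward pass of the class quoted above can be summarized per example as follows. This is a transcription of its einsums (global-context branch for λp), with h indexing heads rather than height and the output channel index formed by flattening (h, v); it is not copied from the paper.

```latex
\begin{aligned}
\bar K_{u,k,m} &= \frac{\exp K_{u,k,m}}{\sum_{m'} \exp K_{u,k,m'}} \\
\lambda^c_{k,v} &= \sum_{u}\sum_{m} \bar K_{u,k,m}\, V_{u,v,m} \\
\lambda^p_{n,k,v} &= \sum_{u}\sum_{m} E_{n,m,k,u}\, V_{u,v,m} \\
Y_{n,h,v} &= \sum_{k} Q_{h,k,n}\bigl(\lambda^c_{k,v} + \lambda^p_{n,k,v}\bigr)
\end{aligned}
```

In the local-context branch, λp is instead produced by the 3-D convolution of V with R, and the last line is unchanged.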
  • Image Size

    Are non-square image blocks allowed for context? Using global context with non-square dimensions (96, 128), I get an error about dimension size on this line:

    λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)

    opened by anklebreaker 0
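
The error likely arises because the global relative embedding is built for a square n x n grid, while v has H*W = 96*128 positions. Assuming the embedding is gathered from a table of relative offsets, one way to support rectangular maps is to size that table (2H-1) x (2W-1). rel_pos_embedding is a hypothetical helper, and the tiny 3 x 5 grid keeps the example cheap.

```python
import torch

def rel_pos_embedding(table: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Gather an (h*w) x (h*w) x dim_k x dim_u embedding from a (2h-1) x (2w-1) offset table."""
    ii, jj = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    pos = torch.stack((ii.flatten(), jj.flatten()), dim=-1)  # (h*w, 2) row/col of each pixel
    rel = pos[None, :, :] - pos[:, None, :]                   # rel[a, b] = pos[b] - pos[a]
    rel[..., 0] += h - 1                                      # row offsets into [0, 2h-2]
    rel[..., 1] += w - 1                                      # col offsets into [0, 2w-2]
    return table[rel[..., 0], rel[..., 1]]

h, w, dim_k, dim_u, dim_v = 3, 5, 16, 1, 8
table = torch.randn(2 * h - 1, 2 * w - 1, dim_k, dim_u)  # would be an nn.Parameter in a layer
emb = rel_pos_embedding(table, h, w)                     # (15, 15, 16, 1)
v = torch.randn(2, dim_u, dim_v, h * w)                  # B x dim_u x dim_v x (H * W)
lam_p = torch.einsum("nmku,buvm->bnkv", emb, v)          # B x (H * W) x dim_k x dim_v
```

The gathered embedding has (H*W)^2 position pairs, about 1.5 x 10^8 for a 96 x 128 map, so at that resolution the local-context variant (r), which uses a convolution and accepts any H x W, is probably the more practical choice.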
  • LambdaResNet Implementation?

    I have been looking around and found one implementation of LambdaResNets, but there seem to be some problems with its metrics, and I've found wall-clock performance problems (it runs ~7x slower than a normal ResNet).

    Do you plan on releasing a LambdaResNet model in this repository?

    opened by nollied 4
Owner: Phil Wang