Sinkformers: Transformers with Doubly Stochastic Attention

Overview

Code for the paper "Sinkformers: Transformers with Doubly Stochastic Attention".

Paper

Our paper is available on arXiv: https://arxiv.org/abs/2110.11773

Compat

This package has been developed and tested with Python 3.8. It is therefore not guaranteed to work with earlier versions of Python.

Install the repository on your machine

This package can easily be installed using pip, with the following commands:

pip install numpy
pip install -e .

This will install the package and all its dependencies, listed in requirements.txt.

Each block of commands below has to be run from the root folder, sinkformers. Our code is spread across several sub-repositories (folders). In each of them, we modify the proposed architectures by replacing the SoftMax attention with Sinkhorn attention.
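As an illustration of this change, here is a minimal NumPy sketch of single-head Sinkhorn attention. It is not the PyTorch code used in the sub-repositories, and the function names are hypothetical. It assumes that Sinkhorn's algorithm starts from exp(QK^T / sqrt(d)) and alternates row and column normalizations, starting with rows, so that n_it = 1 reduces to standard SoftMax attention (consistent with the statement below that 1 iteration corresponds to the original Transformer).

```python
import numpy as np


def logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along `axis`, keeping dims for broadcasting.
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))


def sinkhorn_attention(q, k, v, n_it=1):
    """Single-head attention with Sinkhorn-normalized weights (illustrative sketch).

    q, k, v: arrays of shape (n, d). Iterations 1, 3, 5, ... normalize rows and
    iterations 2, 4, ... normalize columns, all in the log domain.
    Returns the output (n, d) and the attention matrix (n, n).
    """
    d = q.shape[-1]
    log_p = q @ k.T / np.sqrt(d)  # raw attention scores
    for it in range(n_it):
        if it % 2 == 0:
            log_p = log_p - logsumexp(log_p, axis=1)  # each row sums to 1
        else:
            log_p = log_p - logsumexp(log_p, axis=0)  # each column sums to 1
    p = np.exp(log_p)
    return p @ v, p


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 8)) for _ in range(3))

_, p_soft = sinkhorn_attention(q, k, v, n_it=1)    # plain SoftMax attention
_, p_sink = sinkhorn_attention(q, k, v, n_it=101)  # many Sinkhorn steps
print("column sums, n_it=1:  ", p_soft.sum(axis=0).round(3))
print("column sums, n_it=101:", p_sink.sum(axis=0).round(3))
```

With an odd n_it the last step normalizes rows, so rows sum to 1 exactly while column sums are only approximately 1. Working in the log domain avoids overflow in exp for large scores.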

Defining a toy Sinkformer for which attention matrices are doubly stochastic

For this example we use a Transformer from the nlp-tutorial library and define its Sinkformer counterpart with the argument "n_it", the number of iterations in Sinkhorn's algorithm.

cd nlp-tutorial/text-classification-transformer

# Then, in a Python session started from this folder:
import torch
from model import TransformerEncoder

# A batch of 5 identical sequences of token ids 0, ..., 511
inp = torch.arange(512).repeat(5, 1)

# n_it = 1: a single normalization step, i.e. the original SoftMax Transformer
n_it = 1
print('1 iteration in Sinkhorn corresponds to the original Transformer: ')
transformer = TransformerEncoder(vocab_size=1000, seq_len=512, n_layers=1, n_heads=1, n_it=n_it, print_attention=True, pad_id=-1)
out = transformer(inp)

# n_it = 5: further Sinkhorn iterations make the attention matrices (nearly) doubly stochastic
n_it = 5
print('5 iterations in Sinkhorn give a Sinkformer with nearly doubly stochastic attention matrices: ')
sinkformer = TransformerEncoder(vocab_size=1000, seq_len=512, n_layers=1, n_heads=1, n_it=n_it, print_attention=True, pad_id=-1)
out = sinkformer(inp)
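Whether a matrix is doubly stochastic can be checked numerically through its row and column sums. The sketch below is independent of the repository code (the helper names are hypothetical): it applies plain Sinkhorn-Knopp scaling, rows first, to a random positive matrix standing in for exp(QK^T / sqrt(d)), and reports the largest deviation of any row or column sum from 1 for several values of n_it.

```python
import numpy as np


def sinkhorn(kernel, n_it):
    """Alternately normalize rows (odd steps) and columns (even steps) of a positive matrix."""
    p = kernel.copy()
    for it in range(n_it):
        if it % 2 == 0:
            p = p / p.sum(axis=1, keepdims=True)  # rows sum to 1
        else:
            p = p / p.sum(axis=0, keepdims=True)  # columns sum to 1
    return p


def doubly_stochastic_gap(p):
    """Largest deviation of any row or column sum from 1."""
    return max(np.abs(p.sum(axis=1) - 1).max(), np.abs(p.sum(axis=0) - 1).max())


rng = np.random.default_rng(0)
kernel = np.exp(rng.standard_normal((8, 8)))  # random positive matrix
gaps = {n_it: doubly_stochastic_gap(sinkhorn(kernel, n_it)) for n_it in (1, 3, 5, 11, 51)}
for n_it, gap in gaps.items():
    print(f"n_it={n_it:2d}  max |row/col sum - 1| = {gap:.2e}")
```

For odd n_it the rows sum exactly to 1, so the reported gap comes from the columns; how small it is after only a few iterations depends on the scores.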

Then go back to the root:

cd ..
cd ..

Reproducing the experiments of the paper

Comparison of the different normalizations.

python plot_normalizations.py

ModelNet 40 classification. Code adapted from this repository. First, you need to preprocess the ModelNet40 dataset available here: unzip it and save it under model_net_40/data. Then, preferably on a machine with multiple CPUs, run

cd model_net_40
python to_h5.py
python formatting.py
cd ..
mv model_net_40/data/ModelNet40_cloud.h5 set_transformer/ModelNet40_cloud.h5
cd set_transformer
mkdir ../dataset
mv ModelNet40_cloud.h5 ../dataset/ModelNet40_cloud.h5
cd ..
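Taken together, the two mv commands above move model_net_40/data/ModelNet40_cloud.h5 to dataset/ModelNet40_cloud.h5 at the root of the repository (mkdir ../dataset is run from inside set_transformer). A small sanity check, assuming that this is the location the training code expects (as the commands suggest); the helper name is hypothetical:

```python
from pathlib import Path


def modelnet_ready(root="."):
    """Return True if ModelNet40_cloud.h5 sits in <root>/dataset after the moves above."""
    root = Path(root)
    target = root / "dataset" / "ModelNet40_cloud.h5"
    leftover = root / "model_net_40" / "data" / "ModelNet40_cloud.h5"
    if leftover.exists():
        print(f"{leftover} has not been moved yet")
    return target.is_file()


print(modelnet_ready("."))  # run from the root folder sinkformers
```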

Then you can train a Set Sinkformer (or Set Transformer) on ModelNet 40 with

cd set_transformer
python one_expe.py
cd ..

Arguments for one_expe.py can be accessed through

cd set_transformer
python one_expe.py --help
cd ..

Results are saved in the folder set_transformer/results. You can plot the learning curves using the script set_transformer/plot_results.py. The array iterations in the script must contain the different values for "n_it" used when training.
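Since the learning curves are compared across values of n_it, one way to keep the training runs and the iterations array consistent is to generate all training commands from a single list. A minimal sketch; the flag name --n_it is an assumption (the actual argument names are listed by python one_expe.py --help), and the values are illustrative, not those used in the paper.

```python
import shlex

# Illustrative n_it values; keep this list equal to the `iterations` array in plot_results.py.
ITERATIONS = [1, 3, 5]


def sweep_commands(script="one_expe.py", iterations=ITERATIONS, flag="--n_it"):
    """Build one training command per n_it value.

    ASSUMPTION: the flag name "--n_it" is hypothetical; check `python one_expe.py --help`
    for the actual argument name before running these commands.
    """
    return [f"python {script} {flag} {shlex.quote(str(n))}" for n in iterations]


for cmd in sweep_commands():
    print(cmd)  # run each from inside set_transformer/, e.g. with subprocess.run(shlex.split(cmd))
```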

Sentiment Analysis. Code adapted from this repository. You can also train a Sinkformer for sentiment analysis on the IMDb dataset with the following commands (the IMDb dataset is downloaded automatically).

cd nlp-tutorial/text-classification-transformer
python one_expe.py
cd ..
cd ..

Arguments for one_expe.py can be accessed through

cd nlp-tutorial/text-classification-transformer
python one_expe.py --help
cd ..
cd ..

Results are saved in the folder nlp-tutorial/text-classification-transformer/results. You can plot the learning curves using the script nlp-tutorial/text-classification-transformer/plot_results.py. The array iterations in the script must contain the different values for "n_it" used when training.

ViT Cats and Dogs classification. Code adapted from this repository. First, download the dataset here, unzip it, and save the train and test folders in sinkformers/vit-pytorch/examples/data. Then you can run

cd vit-pytorch
python one_expe.py
cd ..

Arguments for one_expe.py can be accessed through

cd vit-pytorch
python one_expe.py --help
cd ..

Results are saved in the folder vit-pytorch/results. You can plot the learning curves using the script vit-pytorch/plot_results.py. The array iterations in the script must contain the different values for "n_it" used when training.

ViT MNIST. The MNIST dataset will be downloaded automatically.

cd vit-pytorch
python one_expe_mnist.py
cd ..

Arguments for one_expe_mnist.py can be accessed through

cd vit-pytorch
python one_expe_mnist.py --help
cd ..

In particular, the argument "ps" is the patch size. Results are saved in the folder vit-pytorch/results_mnist. You can plot the learning curves using the script vit-pytorch/plot_results_mnist.py. The array iterations in the script must contain the different values for "n_it" used when training, and the array patches_size must contain the different values for "ps" used when training.
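The patch size sets the number of tokens, and hence the size of the attention matrices that Sinkhorn's algorithm normalizes. A sketch of that relation for 28 x 28 MNIST images, assuming the usual ViT setup of non-overlapping square patches whose size divides the image side; the helper name and the listed ps values are illustrative only.

```python
def num_patches(image_size=28, ps=7):
    """Number of non-overlapping ps x ps patches in a square image (MNIST images are 28 x 28).

    This is the number of patch tokens n, so each attention matrix is n x n
    (not counting any extra class token the model may add).
    """
    if image_size % ps != 0:
        raise ValueError(f"patch size {ps} does not divide image size {image_size}")
    return (image_size // ps) ** 2


for ps in (4, 7, 14):
    n = num_patches(28, ps)
    print(f"ps={ps:2d}: {n:3d} patches -> attention matrices of size {n}x{n}")
```

Smaller patches mean longer sequences, so each Sinkhorn iteration normalizes a larger matrix.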

Cite

If you use this code in your project, please cite:

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyré
Sinkformers: Transformers with Doubly Stochastic Attention
arXiv preprint arXiv:2110.11773, 2021
https://arxiv.org/abs/2110.11773