Implementation of a Transformer, but completely in Triton

Overview

Transformer in Triton (wip)

Implementation of a Transformer, but completely in Triton. I'm completely new to lower-level neural net code, so this repository will mostly be a learning experience, with the end-goal being a vanilla transformer that is faster and more efficient to train.

Install

$ pip install triton-transformer

Usage

import torch
from triton_transformer import Transformer

model = Transformer(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

x = torch.randint(0, 256, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 256)

Citations

@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need}, 
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
You might also like...
A Pytorch implementation of CVPR 2021 paper "RSG: A Simple but Effective Module for Learning Imbalanced Datasets"

RSG: A Simple but Effective Module for Learning Imbalanced Datasets (CVPR 2021) A Pytorch implementation of our CVPR 2021 paper "RSG: A Simple but Eff

A concise but complete implementation of CLIP with various experimental improvements from recent papers
A concise but complete implementation of CLIP with various experimental improvements from recent papers

x-clip (wip) A concise but complete implementation of CLIP with various experimental improvements from recent papers Install $ pip install x-clip Usag

A concise but complete implementation of CLIP with various experimental improvements from recent papers
A concise but complete implementation of CLIP with various experimental improvements from recent papers

x-clip (wip) A concise but complete implementation of CLIP with various experimental improvements from recent papers Install $ pip install x-clip Usag

Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capability)
Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capability)

Protein GLM (wip) Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capabil

Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

ImageProcessingTransformer Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

Episodic Transformer (E.T.) is a novel attention-based architecture for vision-and-language navigation. E.T. is based on a multimodal transformer that encodes language inputs and the full episode history of visual observations and actions. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

nnFormer: Interleaved Transformer for Volumetric Segmentation Code for paper "nnFormer: Interleaved Transformer for Volumetric Segmentation "

nnFormer: Interleaved Transformer for Volumetric Segmentation Code for paper "nnFormer: Interleaved Transformer for Volumetric Segmentation ". Please

3D-Transformer: Molecular Representation with Transformer in 3D Space

3D-Transformer: Molecular Representation with Transformer in 3D Space

Comments
  • Question concerning PyTorch build

    Question concerning PyTorch build

    Hello. I find your project very interesting and I have seen your comparison between PyTorch and Triton implementations.

    However, I am curious whether your PyTorch environment is a source build optimized for your machine or a pip/conda install.

    Source building has faster runtimes and if a conda install is being used for comparison, the difference in speed may simply be due to Triton optimizing CUDA for the run environment.

    Thank you again for your interesting project.

    opened by veritas9872 13
  • _layernorm implementation forward result not equal F.layer_norm

    _layernorm implementation forward result not equal F.layer_norm

    I have a try on your triton-transformer and test the layernorm module alone. It's very weird that the forward result is different while the backward result is equal.

    code: from triton_transformer.layernorm import layernorm import torch import torch.nn as nn

    torch.manual_seed(0) x = torch.randn(2,5).cuda() x.requires_grad_(True) dy = .1*torch.randn_like(x).cuda() dim = 5 norm = nn.LayerNorm(dim).cuda()

    y1 = layernorm(x, norm.weight, norm.bias, use_triton = True) y2 = layernorm(x, norm.weight, norm.bias, use_triton = False) print(y1, y2) print(torch.allclose(y1, y2))

    y1.backward(dy, retain_graph=True) dx_y1 = x.grad.clone()

    x.grad = None

    y2.backward(dy, retain_graph=True) dx_y2 = x.grad.clone() print(dx_y1, dx_y2) print(torch.allclose(dx_y1, dx_y2))

    result: `tensor([[ 0.9492, -0.0021, -0.9797, 0.4449, -0.4123], [-0.7624, 0.4399, 0.7299, -0.3091, -0.0983]], device='cuda:0', grad_fn=<_layernormBackward>) tensor([[ 1.4217, -0.0031, -1.4674, 0.6663, -0.6175], [-1.4342, 0.8276, 1.3732, -0.5815, -0.1850]], device='cuda:0', grad_fn=) False

    tensor([[-0.0706, 0.0288, -0.0813, 0.0446, 0.0785], [ 0.0218, -0.0152, 0.0141, -0.0522, 0.0315]], device='cuda:0') tensor([[-0.0706, 0.0288, -0.0813, 0.0446, 0.0785], [ 0.0218, -0.0152, 0.0141, -0.0522, 0.0315]], device='cuda:0') True`

    opened by Tengxu-Sun 1
  • Current state of benchmarking & contributing?

    Current state of benchmarking & contributing?

    Hey @lucidrains - hope you're doing well! I have some time to hack the next couple weeks, just wanted to get a sense of:

    • Current state of benchmarking (what Triton kernels provide how much lift, aggregate lift over a "vanilla Transformer implementation"
    • If there's anything I could help with, especially as I learn Triton!
    opened by siddk 0
  • Official layer norm added

    Official layer norm added

    Hi @lucidrains , in Triton layer norm was just added in examples, https://github.com/openai/triton/commit/d4baad426db72b83c5222e1c83c929c1860cae54 I tested it, it's twice as fast as Torch, often faster then Apex.

    I'm looking forward for your implementation of attention, so far the Torch implementation is the fastest with 12.3 / 14.5 (forw / back) vs the other Triton implementation in DeepSpeed which is 17.3/ 23.0 on my data.

    opened by olegklimov 2
Releases(0.1.1)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Recurrent Scale Approximation (RSA) for Object Detection

Recurrent Scale Approximation (RSA) for Object Detection Codebase for Recurrent Scale Approximation for Object Detection in CNN published at ICCV 2017

Yu Liu (Louis) 239 Dec 28, 2022
Official pytorch implementation of Active Learning for deep object detection via probabilistic modeling (ICCV 2021)

Active Learning for Deep Object Detection via Probabilistic Modeling This repository is the official PyTorch implementation of Active Learning for Dee

NVIDIA Research Projects 130 Jan 06, 2023
Official implementation of the PICASO: Permutation-Invariant Cascaded Attentional Set Operator

PICASO Official PyTorch implemetation for the paper PICASO:Permutation-Invariant Cascaded Attentive Set Operator. Requirements Python 3 torch = 1.0 n

Samira Zare 0 Dec 23, 2021
This is the repository for paper NEEDLE: Towards Non-invertible Backdoor Attack to Deep Learning Models.

This is the repository for paper NEEDLE: Towards Non-invertible Backdoor Attack to Deep Learning Models.

1 Oct 25, 2021
A repository for the updated version of CoinRun used to collect MUGEN, a multimodal video-audio-text dataset.

A repository for the updated version of CoinRun used to collect MUGEN, a multimodal video-audio-text dataset. This repo contains scripts to train RL agents to navigate the closed world and collect vi

MUGEN 11 Oct 22, 2022
Improved Fitness Optimization Landscapes for Sequence Design

ReLSO Improved Fitness Optimization Landscapes for Sequence Design Description Citation How to run Training models Original data source Description In

Krishnaswamy Lab 44 Dec 20, 2022
Python Algorithm Interview Book Review

파이썬 알고리즘 인터뷰 책 리뷰 리뷰 IT 대기업에 들어가고 싶은 목표가 있다. 내가 꿈꿔온 회사에서 일하는 사람들의 모습을 보면 멋있다고 생각이 들고 나의 목표에 대한 열망이 강해지는 것 같다. 미래의 핵심 사업 중 하나인 SW 부분을 이끌고 발전시키는 우리나라의 I

SharkBSJ 1 Dec 14, 2021
MassiveSumm: a very large-scale, very multilingual, news summarisation dataset

MassiveSumm: a very large-scale, very multilingual, news summarisation dataset This repository contains links to data and code to fetch and reproduce

Daniel Varab 19 Dec 16, 2022
Pytorch implementation of the paper: "A Unified Framework for Separating Superimposed Images", in CVPR 2020.

Deep Adversarial Decomposition PDF | Supp | 1min-DemoVideo Pytorch implementation of the paper: "Deep Adversarial Decomposition: A Unified Framework f

Zhengxia Zou 72 Dec 18, 2022
Warning: This project does not have any current developer. See bellow.

Pylearn2: A machine learning research library Warning : This project does not have any current developer. We will continue to review pull requests and

Laboratoire d’Informatique des Systèmes Adaptatifs 2.7k Dec 26, 2022
Official implementation of "MetaSDF: Meta-learning Signed Distance Functions"

MetaSDF: Meta-learning Signed Distance Functions Project Page | Paper | Data Vincent Sitzmann*, Eric Ryan Chan*, Richard Tucker, Noah Snavely Gordon W

Vincent Sitzmann 100 Jan 01, 2023
Image De-raining Using a Conditional Generative Adversarial Network

Image De-raining Using a Conditional Generative Adversarial Network [Paper Link] [Project Page] He Zhang, Vishwanath Sindagi, Vishal M. Patel In this

He Zhang 216 Dec 18, 2022
Myia prototyping

Myia Myia is a new differentiable programming language. It aims to support large scale high performance computations (e.g. linear algebra) and their g

Mila 456 Nov 07, 2022
GLANet - The code for Global and Local Alignment Networks for Unpaired Image-to-Image Translation arxiv

GLANet The code for Global and Local Alignment Networks for Unpaired Image-to-Image Translation arxiv Framework: visualization results: Getting Starte

stanley 29 Dec 14, 2022
A pytorch implementation of MBNET: MOS PREDICTION FOR SYNTHESIZED SPEECH WITH MEAN-BIAS NETWORK

Pytorch-MBNet A pytorch implementation of MBNET: MOS PREDICTION FOR SYNTHESIZED SPEECH WITH MEAN-BIAS NETWORK Training To train a new model, please ru

46 Dec 28, 2022
A framework for the elicitation, specification, formalization and understanding of requirements.

A framework for the elicitation, specification, formalization and understanding of requirements.

NASA - Software V&V 161 Jan 03, 2023
Official implementation of SIGIR'2021 paper: "Sequential Recommendation with Graph Neural Networks".

SURGE: Sequential Recommendation with Graph Neural Networks This is our TensorFlow implementation for the paper: Sequential Recommendation with Graph

FIB LAB, Tsinghua University 53 Dec 26, 2022
Learning Time-Critical Responses for Interactive Character Control

Learning Time-Critical Responses for Interactive Character Control Abstract This code implements the paper Learning Time-Critical Responses for Intera

Movement Research Lab 227 Dec 31, 2022
Python tools for 3D face: 3DMM, Mesh processing(transform, camera, light, render), 3D face representations.

face3d: Python tools for processing 3D face Introduction This project implements some basic functions related to 3D faces. You can use this to process

Yao Feng 2.3k Dec 30, 2022
mPose3D, a mmWave-based 3D human pose estimation model.

mPose3D, a mmWave-based 3D human pose estimation model.

KylinChen 35 Nov 08, 2022