Transformer model implemented with PyTorch

Overview

transformer-pytorch

Transformer model implemented with PyTorch

Attention Is All You Need - [Paper]

Architecture

Transformer


Self-Attention

Attention

self_attention.py

[N, len, heads, head_dim] values = values.reshape(N, value_len, self.heads, self.head_dim) keys = keys.reshape(N, key_len, self.heads, self.head_dim) queries = queries.reshape(N, query_len, self.heads, self.head_dim) # Einsum does matrix mult. for query*keys for each training example # with every other training example, don't be confused by einsum # it's just how I like doing matrix multiplication & bmm energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # queries shape: (N, query_len, heads, heads_dim), # keys shape: (N, key_len, heads, heads_dim) # energy: (N, heads, query_len, key_len) # Mask padded indices so their weights become 0 if mask is not None: energy = energy.masked_fill(mask == 0, float("-1e20")) # Normalize energy values similarly to seq2seq + attention # so that they sum to 1. Also divide by scaling factor for # better stability attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3) # attention shape: (N, heads, query_len, key_len) out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( N, query_len, self.heads * self.head_dim ) # attention shape: (N, heads, query_len, key_len) # values shape: (N, value_len, heads, heads_dim) # out after matrix multiply: (N, query_len, heads, head_dim), then # we reshape and flatten the last two dimensions. out = self.fc_out(out) # Linear layer doesn't modify the shape, final shape will be # (N, query_len, embed_size) return out ">
 class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are linearly projected into values, keys and queries, split into
    ``heads`` slices of ``head_dim`` features each, attended per head, then
    concatenated again and passed through a final linear layer.

    Args:
        embed_size: Size of the last dimension of every input tensor.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        """Attend ``query`` over ``keys`` and return the mixed ``values``.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size).
            query: Tensor of shape (N, query_len, embed_size).
            mask: Optional tensor broadcastable to
                (N, heads, query_len, key_len). Positions where the mask
                equals 0 receive no attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into one slice per head:
        # (N, len, embed_size) -> (N, len, heads, head_dim)
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # Dot product of every query with every key, per example and head.
        # energy: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k). Each dot product sums over head_dim features,
        # so head_dim (not the full embed_size) is the right normaliser.
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value of the dtype so the fill
            # also works in reduced precision, where -1e20 overflows.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Normalise over the key axis so each query's weights sum to 1.
        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of the values, then merge the heads back together:
        # (N, query_len, heads, head_dim) -> (N, query_len, embed_size)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)
        out = out.reshape(N, query_len, self.heads * self.head_dim)

        # The output projection keeps the shape (N, query_len, embed_size).
        return self.fc_out(out)

Encoder Block

Encoder

encoder_block.py

class EncoderBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer output is added to its input (skip connection), layer
    normalised, and then passed through dropout.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads handed to ``SelfAttention``.
        dropout: Dropout probability applied after each normalisation.
        forward_expansion: Multiplier for the hidden width of the
            feed-forward network relative to ``embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(EncoderBlock, self).__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and the feed-forward network with skip connections.

        Args:
            value: Values passed to the attention layer.
            key: Keys passed to the attention layer.
            query: Queries; also the residual added to the attention output.
            mask: Attention mask forwarded unchanged to ``SelfAttention``.

        Returns:
            The block output, with the same shape as ``query``.
        """
        # First sub-layer: attention + residual -> norm -> dropout.
        attended = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attended + query))

        # Second sub-layer: feed-forward + residual -> norm -> dropout.
        expanded = self.feed_forward(x)
        return self.dropout(self.norm2(expanded + x))

Encoder

Encoder

encoder.py

class Encoder(nn.Module):
    """Stack of ``EncoderBlock`` layers on top of token + position embeddings.

    Args:
        src_vocab_size: Number of rows in the token embedding table.
        embed_size: Embedding / model feature size.
        num_layers: Number of ``EncoderBlock`` layers to stack.
        heads: Number of attention heads per block.
        device: Stored on the instance for backward compatibility. Position
            indices are created on the input's device instead, so the
            module keeps working after being moved with ``.to(...)``.
        forward_expansion: Feed-forward width multiplier per block.
        dropout: Dropout probability.
        max_length: Number of rows in the position embedding table, i.e.
            the longest sequence that can be encoded.
    """

    def __init__(
            self,
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
    ):

        super(Encoder, self).__init__()
        self.embed_size         = embed_size
        self.device             = device
        self.word_embedding     = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                EncoderBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape (N, seq_length).
            mask: Attention mask forwarded to every layer.

        Returns:
            Tensor of shape (N, seq_length, embed_size).

        Raises:
            IndexError: If ``seq_length`` exceeds the size of the position
                embedding table.
        """
        N, seq_length = x.shape

        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                "sequence length {} exceeds max_length {}".format(
                    seq_length, max_length
                )
            )

        # Build the indices directly on the input's device rather than on
        # the device string captured at construction time, which goes stale
        # once the model is moved.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(
            self.word_embedding(x) + self.position_embedding(positions)
        )

        # In the encoder the value, key and query are all the same tensor.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

Decoder Block

DecoderBlock

decoder_block.py

class DecoderBlock(nn.Module):
    """Masked self-attention followed by an encoder-style attention block.

    The input first attends to itself under the target mask. The normalised
    result then serves as the query of an ``EncoderBlock`` whose values and
    keys are supplied by the caller.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier of the inner block.
        dropout: Dropout probability.
        device: Accepted for interface compatibility; not used by this block.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = EncoderBlock(
            embed_size,
            heads,
            dropout,
            forward_expansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Run one decoder layer.

        Args:
            x: Decoder input, used as value, key and query of the
                self-attention step.
            value: Values for the second attention step.
            key: Keys for the second attention step.
            src_mask: Mask applied in the second attention step.
            trg_mask: Mask applied in the self-attention step.

        Returns:
            Output of the inner ``EncoderBlock``.
        """
        # Self-attention over the decoder input, with a skip connection.
        self_attended = self.attention(x, x, x, trg_mask)
        normalised = self.norm(self_attended + x)
        query = self.dropout(normalised)

        # Attend over the supplied values/keys using that query.
        return self.transformer_block(value, key, query, src_mask)

Decoder

Decoder

decoder.py

class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers with a final vocabulary projection.

    Args:
        trg_vocab_size: Number of rows in the token embedding table and the
            size of the output projection.
        embed_size: Embedding / model feature size.
        num_layers: Number of ``DecoderBlock`` layers to stack.
        heads: Number of attention heads per block.
        forward_expansion: Feed-forward width multiplier per block.
        dropout: Dropout probability.
        device: Stored on the instance and passed to each block for backward
            compatibility. Position indices are created on the input's
            device instead, so the module keeps working after ``.to(...)``.
        max_length: Number of rows in the position embedding table, i.e.
            the longest sequence that can be decoded.
    """

    def __init__(
            self,
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
    ):
        super(Decoder, self).__init__()
        self.device             = device
        self.word_embedding     = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)
        self.fc_out  = nn.Linear(embed_size, trg_vocab_size)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: Integer tensor of target token ids, shape (N, seq_length).
            enc_out: Encoder output, used as value and key in every layer.
            src_mask: Mask for attention over ``enc_out``.
            trg_mask: Mask for the decoder's self-attention.

        Returns:
            Tensor of shape (N, seq_length, trg_vocab_size).

        Raises:
            IndexError: If ``seq_length`` exceeds the size of the position
                embedding table.
        """
        N, seq_length = x.shape

        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                "sequence length {} exceeds max_length {}".format(
                    seq_length, max_length
                )
            )

        # Build the indices directly on the input's device rather than on
        # the device string captured at construction time, which goes stale
        # once the model is moved.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(
            self.word_embedding(x) + self.position_embedding(positions)
        )

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        # Project every position onto the target vocabulary.
        return self.fc_out(x)

Transformer

transformer.py

class Transformer(nn.Module):
    """Encoder-decoder Transformer built from ``Encoder`` and ``Decoder``.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        src_pad_idx: Token id treated as padding in the source; such
            positions are masked out of attention.
        trg_pad_idx: Target padding token id. Stored on the instance; the
            target mask built here is purely causal and does not use it.
        embed_size: Embedding / model feature size.
        num_layers: Number of layers in both the encoder and the decoder.
        forward_expansion: Feed-forward width multiplier.
        heads: Number of attention heads.
        dropout: Dropout probability.
        device: Stored on the instance and passed to the sub-modules for
            backward compatibility. Masks are created on the device of the
            token tensors, so the model keeps working after ``.to(...)``.
        max_length: Maximum sequence length supported by the position
            embeddings.
    """

    def __init__(
            self,
            src_vocab_size,
            trg_vocab_size,
            src_pad_idx,
            trg_pad_idx,
            embed_size=512,
            num_layers=6,
            forward_expansion=4,
            heads=8,
            dropout=0,
            device="cpu",
            max_length=100,
    ):

        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device      = device

    def make_src_mask(self, src):
        """Return a boolean padding mask of shape (N, 1, 1, src_len).

        Entries are True where ``src`` differs from ``src_pad_idx``. The
        mask lives on the same device as ``src``.
        """
        # The comparison result is already on src's device; moving it to a
        # constructor-time device string would break a relocated model.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a lower-triangular mask of shape (N, 1, trg_len, trg_len).

        Position i may attend to positions 0..i only. The mask is created
        directly on the same device as ``trg``.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against it.

        Args:
            src: Integer tensor of source token ids, shape (N, src_len).
            trg: Integer tensor of target token ids, shape (N, trg_len).

        Returns:
            The decoder output for every target position.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_mask, trg_mask)

Authors

Owner
Mingu Kang
SW Engineering / ML / DL / Blockchain Dept. of Software Engineering, Jeonbuk National University
Mingu Kang
EMNLP 2021: Single-dataset Experts for Multi-dataset Question-Answering

MADE (Multi-Adapter Dataset Experts) This repository contains the implementation of MADE (Multi-adapter dataset experts), which is described in the pa

Princeton Natural Language Processing 68 Jul 18, 2022
RodoSol-ALPR Dataset

RodoSol-ALPR Dataset This dataset, called RodoSol-ALPR dataset, contains 20,000 images captured by static cameras located at pay tolls owned by the Ro

Rayson Laroca 45 Dec 15, 2022
The code for our paper submitted to RAL/IROS 2022: OverlapTransformer: An Efficient and Rotation-Invariant Transformer Network for LiDAR-Based Place Recognition.

OverlapTransformer The code for our paper submitted to RAL/IROS 2022: OverlapTransformer: An Efficient and Rotation-Invariant Transformer Network for

HAOMO.AI 136 Jan 03, 2023
Unified Instance and Knowledge Alignment Pretraining for Aspect-based Sentiment Analysis

Unified Instance and Knowledge Alignment Pretraining for Aspect-based Sentiment Analysis Requirements python 3.7 pytorch-gpu 1.7 numpy 1.19.4 pytorch_

12 Oct 29, 2022
The Surprising Effectiveness of Visual Odometry Techniques for Embodied PointGoal Navigation

PointNav-VO The Surprising Effectiveness of Visual Odometry Techniques for Embodied PointGoal Navigation Project Page | Paper Table of Contents Setup

Xiaoming Zhao 41 Dec 15, 2022
Demos of essentia classifiers hosted on replicate.ai

essentia-replicate-demos Demos of Essentia models hosted on replicate.ai's MTG site. The models Check our site for a complete list of the models avail

Music Technology Group - Universitat Pompeu Fabra 12 Nov 14, 2022
This is a vision-based 3d model manipulation and control UI

Manipulation of 3D Models Using Hand Gesture This program allows users to manipulate 3D models (.obj format) with their hands. The project support bo

Cortic Technology Corp. 43 Oct 23, 2022
Patch SVDD for Image anomaly detection

Patch SVDD Patch SVDD for Image anomaly detection. Paper: https://arxiv.org/abs/2006.16067 (published in ACCV 2020). Original Code : https://github.co

Hong-Jeongmin 0 Dec 03, 2021
Official PyTorch code for Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021)

Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021) This repository is the official P

Jingyun Liang 159 Dec 30, 2022
Transformers are Graph Neural Networks!

🚀 Gated Graph Transformers Gated Graph Transformers for graph-level property prediction, i.e. graph classification and regression. Associated article

Chaitanya Joshi 46 Jun 30, 2022
Rethinking Space-Time Networks with Improved Memory Coverage for Efficient Video Object Segmentation

STCN Rethinking Space-Time Networks with Improved Memory Coverage for Efficient Video Object Segmentation Ho Kei Cheng, Yu-Wing Tai, Chi-Keung Tang [a

Rex Cheng 456 Dec 12, 2022
Sound-guided Semantic Image Manipulation - Official Pytorch Code (CVPR 2022)

🔉 Sound-guided Semantic Image Manipulation (CVPR2022) Official Pytorch Implementation Sound-guided Semantic Image Manipulation IEEE/CVF Conference on

CVLAB 58 Dec 28, 2022
ByteTrack with ReID module following the paradigm of FairMOT, tracking strategy is borrowed from FairMOT/JDE.

ByteTrack_ReID ByteTrack is the SOTA tracker in MOT benchmarks with strong detector YOLOX and a simple association strategy only based on motion infor

Han GuangXin 46 Dec 29, 2022
Data and codes for ACL 2021 paper: Towards Emotional Support Dialog Systems

Emotional-Support-Conversation Copyright © 2021 CoAI Group, Tsinghua University. All rights reserved. Data and codes are for academic research use onl

126 Dec 21, 2022
Your interactive network visualizing dashboard

Your interactive network visualizing dashboard Documentation: Here What is Jaal Jaal is a python based interactive network visualizing tool built usin

Mohit 177 Jan 04, 2023
DSL for matching Python ASTs

py-ast-rule-engine This library provides a DSL (domain-specific language) to match a pattern inside a Python AST (abstract syntax tree). The library i

1 Dec 18, 2021
the official implementation of the paper "Isometric Multi-Shape Matching" (CVPR 2021)

Isometric Multi-Shape Matching (IsoMuSh) Paper-CVF | Paper-arXiv | Video | Code Citation If you find our work useful in your research, please consider

Maolin Gao 9 Jul 17, 2022
A Comprehensive Analysis of Weakly-Supervised Semantic Segmentation in Different Image Domains (IJCV submission)

wsss-analysis The code of: A Comprehensive Analysis of Weakly-Supervised Semantic Segmentation in Different Image Domains, arXiv pre-print 2019 paper.

Lyndon Chan 48 Dec 18, 2022
Image Lowpoly based on Centroid Voronoi Diagram via python-opencv and taichi

CVTLowpoly: Image Lowpoly via Centroid Voronoi Diagram Image Sharp Feature Extraction using Guide Filter's Local Linear Theory via opencv-python. The

Pupa 4 Jul 29, 2022
This repository contains the code for our paper VDA (public in EMNLP2021 main conference)

Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models This repository contains the code for our paper VDA (publ

RUCAIBox 13 Aug 06, 2022