Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using Pytorch

Overview

Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using Pytorch

Reference

  • Paper URL

  • Author: Yi Tay, Dara Bahri, Donald Metzler, Da-Cheng Juan, Zhe Zhao, Che Zheng

  • Google Research

Method

model

1. Dense Synthesizer

2. Fixed Random Synthesizer

3. Random Synthesizer

4. Factorized Dense Synthesizer

5. Factorized Random Synthesizer

6. Mixture of Synthesizers

Usage

import torch

from synthesizer import Transformer, SynthesizerDense, SynthesizerRandom, FactorizedSynthesizerDense, FactorizedSynthesizerRandom, MixtureSynthesizers, get_n_params, calculate_flops


def main():
    batch_size, channel_dim, sentence_length = 2, 1024, 32
    x = torch.randn([batch_size, sentence_length, channel_dim])

    vanilla = Transformer(channel_dim)
    out, attention_map = vanilla(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(vanilla), calculate_flops(vanilla.children())
    print('vanilla, n_params: {}, flops: {}'.format(n_params, flops))

    dense_synthesizer = SynthesizerDense(channel_dim, sentence_length)
    out, attention_map = dense_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(dense_synthesizer), calculate_flops(dense_synthesizer.children())
    print('dense_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))

    random_synthesizer = SynthesizerRandom(channel_dim, sentence_length)
    out, attention_map = random_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(random_synthesizer), calculate_flops(random_synthesizer.children())
    print('random_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))

    random_synthesizer_fix = SynthesizerRandom(channel_dim, sentence_length, fixed=True)
    out, attention_map = random_synthesizer_fix(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(random_synthesizer_fix), calculate_flops(random_synthesizer_fix.children())
    print('random_synthesizer_fix, n_params: {}, flops: {}'.format(n_params, flops))

    factorized_synthesizer_random = FactorizedSynthesizerRandom(channel_dim)
    out, attention_map = factorized_synthesizer_random(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(factorized_synthesizer_random), calculate_flops(
        factorized_synthesizer_random.children())
    print('factorized_synthesizer_random, n_params: {}, flops: {}'.format(n_params, flops))

    factorized_synthesizer_dense = FactorizedSynthesizerDense(channel_dim, sentence_length)
    out, attention_map = factorized_synthesizer_dense(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(factorized_synthesizer_dense), calculate_flops(
        factorized_synthesizer_dense.children())
    print('factorized_synthesizer_dense, n_params: {}, flops: {}'.format(n_params, flops))

    mixture_synthesizer = MixtureSynthesizers(channel_dim, sentence_length)
    out, attention_map = mixture_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(mixture_synthesizer), calculate_flops(mixture_synthesizer.children())
    print('mixture_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))


if __name__ == '__main__':
    main()

Output

torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
vanilla, n_params: 3148800, flops: 3145729
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
dense_synthesizer, n_params: 1083456, flops: 1082370
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer_fix, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_random, n_params: 1066000, flops: 1064961
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_dense, n_params: 1061900, flops: 1060865
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
mixture_synthesizer, n_params: 3149824, flops: 3145729

Paper Performance

eval

Owner
Myeongjun Kim
Computer Vision Research using Deep Learning
Myeongjun Kim
Tree LSTM implementation in PyTorch

Tree-Structured Long Short-Term Memory Networks This is a PyTorch implementation of Tree-LSTM as described in the paper Improved Semantic Representati

Riddhiman Dasgupta 529 Dec 10, 2022
Implementation for paper "STAR: A Structure-aware Lightweight Transformer for Real-time Image Enhancement" (ICCV 2021).

STAR-pytorch Implementation for paper "STAR: A Structure-aware Lightweight Transformer for Real-time Image Enhancement" (ICCV 2021). CVF (pdf) STAR-DC

43 Dec 21, 2022
DeepLab2: A TensorFlow Library for Deep Labeling

DeepLab2 is a TensorFlow library for deep labeling, aiming to provide a unified and state-of-the-art TensorFlow codebase for dense pixel labeling tasks.

Google Research 845 Jan 04, 2023
Large dataset storage format for Pytorch

H5Record Large dataset ( 100G, = 1T) storage format for Pytorch (wip) Support python 3 pip install h5record Why? Writing large dataset is still a

theblackcat102 43 Oct 22, 2022
WTTE-RNN a framework for churn and time to event prediction

WTTE-RNN Weibull Time To Event Recurrent Neural Network A less hacky machine-learning framework for churn- and time to event prediction. Forecasting p

Egil Martinsson 727 Dec 28, 2022
A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

CapsNet-Tensorflow A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules Notes: The current version

Huadong Liao 3.8k Dec 29, 2022
Galactic and gravitational dynamics in Python

Gala is a Python package for Galactic and gravitational dynamics. Documentation The documentation for Gala is hosted on Read the docs. Installation an

Adrian Price-Whelan 101 Dec 22, 2022
The official homepage of the (outdated) COCO-Stuff 10K dataset.

COCO-Stuff 10K dataset v1.1 (outdated) Holger Caesar, Jasper Uijlings, Vittorio Ferrari Overview Welcome to official homepage of the COCO-Stuff [1] da

Holger Caesar 263 Dec 11, 2022
Code to reproduce the results in the paper "Tensor Component Analysis for Interpreting the Latent Space of GANs".

Tensor Component Analysis for Interpreting the Latent Space of GANs [ paper | project page ] Code to reproduce the results in the paper "Tensor Compon

James Oldfield 4 Jun 17, 2022
Supervised Contrastive Learning for Product Matching

Contrastive Product Matching This repository contains the code and data download links to reproduce the experiments of the paper "Supervised Contrasti

Web-based Systems Group @ University of Mannheim 18 Dec 10, 2022
Simulation code and tutorial for BBHnet training data

Simulation Dataset for BBHnet NOTE: OLD README, UPDATE IN PROGRESS We generate simulation dataset to train BBHnet, our deep learning framework for det

0 May 31, 2022
Official Pytorch implementation of the paper: "Locally Shifted Attention With Early Global Integration"

Locally-Shifted-Attention-With-Early-Global-Integration Pretrained models You can download all the models from here. Training Imagenet python -m torch

Shelly Sheynin 14 Apr 15, 2022
Music source separation is a task to separate audio recordings into individual sources

Music Source Separation Music source separation is a task to separate audio recordings into individual sources. This repository is an PyTorch implmeme

Bytedance Inc. 958 Jan 03, 2023
Code for the paper A Theoretical Analysis of the Repetition Problem in Text Generation

A Theoretical Analysis of the Repetition Problem in Text Generation This repository share the code for the paper "A Theoretical Analysis of the Repeti

Zihao Fu 37 Nov 21, 2022
This program uses trial auth token of Azure Cognitive Services to do speech synthesis for you.

🗣️ aspeak A simple text-to-speech client using azure TTS API(trial). 😆 TL;DR: This program uses trial auth token of Azure Cognitive Services to do s

Levi Zim 359 Jan 05, 2023
This program creates a formatted excel file which highlights the undervalued stock according to Graham's number.

Over-and-Undervalued-Stocks Of Nepse Using Graham's Number Scrap the latest data using different websites and creates a formatted excel file that high

6 May 03, 2022
Lucid Sonic Dreams syncs GAN-generated visuals to music.

Lucid Sonic Dreams Lucid Sonic Dreams syncs GAN-generated visuals to music. By default, it uses NVLabs StyleGAN2, with pre-trained models lifted from

731 Jan 02, 2023
Code for paper "ASAP-Net: Attention and Structure Aware Point Cloud Sequence Segmentation"

ASAP-Net This project implements ASAP-Net of paper ASAP-Net: Attention and Structure Aware Point Cloud Sequence Segmentation (BMVC2020). Overview We i

Hanwen Cao 26 Aug 25, 2022
This is the workbook I created while I was studying for the Qiskit Associate Developer exam. I hope this becomes useful to others as it was for me :)

A Workbook for the Qiskit Developer Certification Exam Hello everyone! This is Bartu, a fellow Qiskitter. I have recently taken the Certification exam

Bartu Bisgin 66 Dec 10, 2022
Inference code for "StylePeople: A Generative Model of Fullbody Human Avatars" paper. This code is for the part of the paper describing video-based avatars.

NeuralTextures This is repository with inference code for paper "StylePeople: A Generative Model of Fullbody Human Avatars" (CVPR21). This code is for

Visual Understanding Lab @ Samsung AI Center Moscow 18 Oct 06, 2022