pytorch implementation of Attention is all you need

Overview

A Pytorch Implementation of the Transformer: Attention Is All You Need

Our implementation is largely based on Tensorflow implementation

Requirements

Why This Project?

I'm a freshman of pytorch. So I tried to implement some projects by pytorch. Recently, I read the paper Attention is all you need and impressed by the idea. So that's it. I got similar result compared with the original tensorflow implementation.

Differences with the original paper

I don't intend to replicate the paper exactly. Rather, I aim to implement the main ideas in the paper and verify them in a SIMPLE and QUICK way. In this respect, some parts in my code are different than those in the paper. Among them are

  • I used the IWSLT 2016 de-en dataset, not the wmt dataset because the former is much smaller, and requires no special preprocessing.
  • I constructed vocabulary with words, not subwords for simplicity. Of course, you can try bpe or word-piece if you want.
  • I parameterized positional encoding. The paper used some sinusoidal formula, but Noam, one of the authors, says they both work. See the discussion in reddit
  • The paper adjusted the learning rate to global steps. I fixed the learning to a small number, 0.0001 simply because training was reasonably fast enough with the small dataset (Only a couple of hours on a single GTX 1060!!).

File description

  • hyperparams.py includes all hyper parameters that are needed.
  • prepro.py creates vocabulary files for the source and the target.
  • data_load.py contains functions regarding loading and batching data.
  • modules.py has all building blocks for encoder/decoder networks.
  • train.py has the model.
  • eval.py is for evaluation.

Training

wget -qO- https://wit3.fbk.eu/archive/2016-01//texts/de/en/de-en.tgz | tar xz; mv de-en corpora
  • STEP 2. Adjust hyper parameters in hyperparams.py if necessary.
  • STEP 3. Run prepro.py to generate vocabulary files to the preprocessed folder.
  • STEP 4. Run train.py or download pretrained weights, put it into folder './models/' and change the eval_epoch in hpyerparams.py to 18
  • STEP 5. Show loss and accuracy in tensorboard
tensorboard --logdir runs

Evaluation

  • Run eval.py.

Results

I got a BLEU score of 16.7.(tensorflow implementation 17.14) (Recollect I trained with a small dataset, limited vocabulary) Some of the evaluation results are as follows. Details are available in the results folder.

source: Ich bin nicht sicher was ich antworten soll
expected: I'm not really sure about the answer
got: I'm not sure what I'm going to answer

source: Was macht den Unterschied aus
expected: What makes his story different
got: What makes a difference

source: Vielen Dank
expected: Thank you
got: Thank you

source: Das ist ein Baum
expected: This is a tree
got: So this is a tree

Owner
Phd in SJTU
Fuwa-http - The http client implementation for the fuwa eco-system

Fuwa HTTP The HTTP client implementation for the fuwa eco-system Example import

Fuwa 2 Feb 16, 2022
Explaining Deep Neural Networks - A comparison of different CAM methods based on an insect data set

Explaining Deep Neural Networks - A comparison of different CAM methods based on an insect data set This is the repository for the Deep Learning proje

Robert Krug 3 Feb 06, 2022
SmartSim Infrastructure Library.

Home Install Documentation Slack Invite Cray Labs SmartSim SmartSim makes it easier to use common Machine Learning (ML) libraries like PyTorch and Ten

Cray Labs 139 Jan 01, 2023
Home repository for the Regularized Greedy Forest (RGF) library. It includes original implementation from the paper and multithreaded one written in C++, along with various language-specific wrappers.

Regularized Greedy Forest Regularized Greedy Forest (RGF) is a tree ensemble machine learning method described in this paper. RGF can deliver better r

RGF-team 364 Dec 28, 2022
Disentangled Face Attribute Editing via Instance-Aware Latent Space Search, accepted by IJCAI 2021.

Instance-Aware Latent-Space Search This is a PyTorch implementation of the following paper: Disentangled Face Attribute Editing via Instance-Aware Lat

67 Dec 21, 2022
Python library for analysis of time series data including dimensionality reduction, clustering, and Markov model estimation

deeptime Releases: Installation via conda recommended. conda install -c conda-forge deeptime pip install deeptime Documentation: deeptime-ml.github.io

495 Dec 28, 2022
U-2-Net: U Square Net - Modified for paired image training of style transfer

U2-Net: U Square Net Modified for paired image training of style transfer This is an unofficial repo making use of the code which was made available b

Doron Adler 43 Oct 03, 2022
Dynamic Capacity Networks using Tensorflow

Dynamic Capacity Networks using Tensorflow Dynamic Capacity Networks (DCN; http://arxiv.org/abs/1511.07838) implementation using Tensorflow. DCN reduc

Taeksoo Kim 8 Feb 23, 2021
Official code repository for the publication "Latent Equilibrium: A unified learning theory for arbitrarily fast computation with arbitrarily slow neurons"

Latent Equilibrium: A unified learning theory for arbitrarily fast computation with arbitrarily slow neurons This repository contains the code to repr

Computational Neuroscience, University of Bern 3 Aug 04, 2022
Systemic Evolutionary Chemical Space Exploration for Drug Discovery

SECSE SECSE: Systemic Evolutionary Chemical Space Explorer Chemical space exploration is a major task of the hit-finding process during the pursuit of

64 Dec 16, 2022
PyTorch-based framework for Deep Hedging

PFHedge: Deep Hedging in PyTorch PFHedge is a PyTorch-based framework for Deep Hedging. PFHedge Documentation Neural Network Architecture for Efficien

139 Dec 30, 2022
SkipGNN: Predicting Molecular Interactions with Skip-Graph Networks (Scientific Reports)

SkipGNN: Predicting Molecular Interactions with Skip-Graph Networks Molecular interaction networks are powerful resources for the discovery. While dee

Kexin Huang 49 Oct 15, 2022
Mall-Customers-Segmentation - Customer Segmentation Using K-Means Clustering

Overview Customer Segmentation is one the most important applications of unsupervised learning. Using clustering techniques, companies can identify th

NelakurthiSudheer 2 Jan 03, 2022
[NeurIPS-2021] Slow Learning and Fast Inference: Efficient Graph Similarity Computation via Knowledge Distillation

Efficient Graph Similarity Computation - (EGSC) This repo contains the source code and dataset for our paper: Slow Learning and Fast Inference: Effici

23 Nov 11, 2022
Saeed Lotfi 28 Dec 12, 2022
Towards Debiasing NLU Models from Unknown Biases

Towards Debiasing NLU Models from Unknown Biases Abstract: NLU models often exploit biased features to achieve high dataset-specific performance witho

Ubiquitous Knowledge Processing Lab 22 Jun 14, 2022
Creating a Linear Program Solver by Implementing the Simplex Method in Python with NumPy

Creating a Linear Program Solver by Implementing the Simplex Method in Python with NumPy Simplex Algorithm is a popular algorithm for linear programmi

Reda BELHAJ 2 Oct 12, 2022
FPSAutomaticAiming——基于YOLOV5的FPS类游戏自动瞄准AI

FPSAutomaticAiming——基于YOLOV5的FPS类游戏自动瞄准AI 声明: 本项目仅限于学习交流,不可用于非法用途,包括但不限于:用于游戏外挂等,使用本项目产生的任何后果与本人无关! 简介 本项目基于yolov5,实现了一款FPS类游戏(CF、CSGO等)的自瞄AI,本项目旨在使用现

Fabian 246 Dec 28, 2022
A small demonstration of using WebDataset with ImageNet and PyTorch Lightning

A small demonstration of using WebDataset with ImageNet and PyTorch Lightning This is a small repo illustrating how to use WebDataset on ImageNet. usi

50 Dec 16, 2022
Physics-informed Neural Operator for Learning Partial Differential Equation

PINO Physics-informed Neural Operator for Learning Partial Differential Equation Abstract: Machine learning methods have recently shown promise in sol

107 Jan 02, 2023