Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision

Overview

This repo contains the implementation of our paper:

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision

Paper Link

Replication

Python environment

pip install -e . # under DSLP directory
pip install tensorflow tensorboard sacremoses nltk Ninja omegaconf
pip install 'fuzzywuzzy[speedup]'
pip install hydra-core==1.0.6
pip install sacrebleu==1.5.1
pip install git+https://github.com/dugu9sword/lunanlp.git
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode && pip install .
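
After installing, a quick sanity check (a minimal sketch; it only assumes the packages above installed correctly) confirms that the key dependencies import:

python3 -c "import torch, fairseq; print('fairseq', fairseq.__version__, '| torch', torch.__version__)"
python3 -c "from ctcdecode import CTCBeamDecoder; print('ctcdecode OK')"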

Dataset

We downloaded the distilled data from FairSeq

Preprocess the data with:

TEXT=wmt14_ende_distill
python3 fairseq_cli/preprocess.py --source-lang en --target-lang de \
   --trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
   --destdir data-bin/wmt14.en-de_kd --workers 40 --joined-dictionary

Or you can download all the binarized files here.
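
After preprocessing (or downloading the binarized files), data-bin/wmt14.en-de_kd should contain the dictionaries and the binarized splits. The listing below shows what a typical fairseq-preprocess run with --joined-dictionary produces (exact file names may vary):

ls data-bin/wmt14.en-de_kd
# expected contents (typical fairseq-preprocess output):
#   dict.en.txt  dict.de.txt  preprocess.log
#   train.en-de.en.bin  train.en-de.en.idx  train.en-de.de.bin  train.en-de.de.idx
#   valid.en-de.*  test.en-de.*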

Hyperparameters

                               EN<->RO   EN<->DE
--validate-interval-updates    300       500
number of tokens per batch     32K       128K
--dropout                      0.3       0.1

Note:

  1. We found that label smoothing is not useful for CTC-based models (at least not with our implementation); we suggest keeping --label-smoothing at 0 for them.
  2. The dropout rate plays a significant role for GLAT, CMLM, and the vanilla NAT. On WMT'14 EN->DE, for example, the vanilla NAT reaches 21.18 BLEU with dropout 0.1 but only 19.68 BLEU with dropout 0.3.
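
In fairseq, the effective tokens-per-batch is roughly --max-tokens x (number of GPUs) x --update-freq. As a hedged sketch of how the table above maps onto the training commands below (the --update-freq values are our assumption for an 8-GPU machine, not part of the original commands), the per-task settings can be collected into shell variables:

# EN<->DE: 8192 x 8 GPUs x 2 = 131,072 (~128K) tokens per batch
EN_DE_FLAGS="--max-tokens 8192 --update-freq 2 --dropout 0.1 --validate-interval-updates 500"
# EN<->RO: 4096 x 8 GPUs x 1 = 32,768 (~32K) tokens per batch
EN_RO_FLAGS="--max-tokens 4096 --update-freq 1 --dropout 0.3 --validate-interval-updates 300"

These can then be appended to the corresponding training command below in place of its --max-tokens, --dropout, and --validate-interval-updates flags.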

Training

We provide scripts for replicating the results on the WMT'14 EN->DE task. For other tasks, you need to adapt the binarized data path, --source-lang, --target-lang, and some other hyperparameters accordingly.

GLAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_glat --criterion glat_loss --arch glat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --glat-mode glat \ 
   --length-loss-factor 0.1 --pred-length-offset 
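
Optionally, since TensorBoard is installed above, training curves can be monitored by appending fairseq's --tensorboard-logdir flag to any of the commands in this section (for example --tensorboard-logdir checkpoints/tb; the path is only a suggestion) and then launching:

tensorboard --logdir checkpoints/tb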

CMLM with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch glat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 \
   --length-loss-factor 0.1 --pred-length-offset 

Vanilla NAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 \
   --length-loss-factor 0.1 --pred-length-offset 

Vanilla NAT with DSLP and Mixed Training

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192  --ss-ratio 0.3 --fixed-ss-ratio --masked-loss \ 
   --length-loss-factor 0.1 --pred-length-offset 

CTC with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.0 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

CTC with DSLP and Mixed Training

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd_ss --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.0 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --ss-ratio 0.3 --fixed-ss-ratio

Evaluation

Average the 5 best checkpoints with scripts/average_checkpoints.py. Our results are based on either the best single checkpoint or the averaged checkpoint, whichever has the higher validation-set BLEU.
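
For reference, the averaging step might look like the following; the checkpoint.best_bleu_* naming assumes fairseq's --keep-best-checkpoints behaviour with --best-checkpoint-metric bleu, and the output name is arbitrary, so adjust the paths to whatever is in your --save-dir:

python3 scripts/average_checkpoints.py \
    --inputs checkpoints/checkpoint.best_bleu_*.pt \
    --output checkpoints/checkpoint.avg_best.pt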

fairseq-generate data-bin/wmt14.en-de_kd  --path PATH_TO_A_CHECKPOINT \
    --gen-subset test --task translation_lev --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100

Note:

  1. Add --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}' for CTC-based models (see the example below).
  2. Change the task to translation_glat for GLAT-based models.
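
For example, a complete command for a CTC-based checkpoint (a sketch that simply combines the command above with note 1; all other flags unchanged):

fairseq-generate data-bin/wmt14.en-de_kd --path PATH_TO_A_CHECKPOINT \
    --gen-subset test --task translation_lev --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100 \
    --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}'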

Output

We additionally provide the outputs of CTC w/ DSLP, CTC w/ DSLP & Mixed Training, Vanilla NAT w/ DSLP, Vanilla NAT w/ DSLP & Mixed Training, GLAT w/ DSLP, and CMLM w/ DSLP for review purposes.

Model                                   Reference   Hypothesis
CTC w/ DSLP                             ref         hyp
CTC w/ DSLP & Mixed Training            ref         hyp
Vanilla NAT w/ DSLP                     ref         hyp
Vanilla NAT w/ DSLP & Mixed Training    ref         hyp
GLAT w/ DSLP                            ref         hyp
CMLM w/ DSLP                            ref         hyp

Note: The outputs are on WMT'14 EN->DE. Each model's reference file is line-aligned with its hypothesis file.
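
Because each reference file is line-aligned with its hypothesis file, tokenized BLEU on these outputs can be recomputed with fairseq's scorer; the file names below are placeholders for the downloaded ref/hyp pair of a given model:

fairseq-score --sys model_output.hyp --ref model_output.ref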

Training Efficiency

We show the training efficiency of DSLP on top of the vanilla NAT model. Specifically, we compare the BLEU scores of the vanilla NAT and the vanilla NAT with DSLP & Mixed Training at the same training time (in hours).

As we observe, our DSLP model reaches much higher BLEU scores shortly after training starts (~3 hours). This shows that DSLP is much more efficient to train, as it achieves higher BLEU scores at the same training cost.

(Figure "Efficiency": validation BLEU versus training time in hours for the vanilla NAT and the vanilla NAT with DSLP & Mixed Training.)

We ran the experiments with 8 Tesla V100 GPUs. The batch size is 128K tokens, and each model is trained for 300K updates.
