Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision
Training Efficiency
We show the training efficiency of our DSLP model based on the vanilla NAT model. Specifically, we compare the BLEU scores of the vanilla NAT model and the vanilla NAT model with DSLP & Mixed Training at the same training time (in hours).
As observed, our DSLP model achieves much higher BLEU scores shortly after training starts (~3 hours). This shows that DSLP is considerably more training-efficient, as our model achieves higher BLEU scores at the same training cost.
We run the experiments with 8 Tesla V100 GPUs. The batch size is 128K tokens, and each model is trained for 300K updates.
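A rough way to reproduce such a curve is to read the validation BLEU directly from the training log: the commands below enable --eval-bleu with --log-format simple, so validation lines report a bleu field, and (assuming fairseq's default timestamped logging prefix) each line also carries its wall-clock time. A minimal sketch, where train.log is a placeholder for wherever you redirect the training output:

# Sketch: list validation lines together with their BLEU scores and timestamps.
# train.log is a placeholder; adjust the patterns if your log format differs.
grep 'valid ' train.log | grep 'bleu '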
Replication
We provide the scripts for replicating the results on the WMT'14 EN-DE task.
Dataset
We download the distilled data from FairSeq.
Preprocess it with:
TEXT=wmt14_ende_distill
python3 fairseq_cli/preprocess.py --source-lang en --target-lang de \
--trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
--destdir data-bin/wmt14.en-de_kd --workers 40 --joined-dictionary
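The preprocessing command above expects the split prefixes with language suffixes under the TEXT directory, i.e., a layout like the following (the exact file names inside the downloaded archive may differ slightly):

wmt14_ende_distill/
    train.en-de.en    # English (source) side of the distilled training set
    train.en-de.de    # German (target) side of the distilled training set
    valid.en-de.en
    valid.en-de.de
    test.en-de.en
    test.en-de.de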
Training
GLAT with DSLP
python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de --save-dir checkpoints --eval-tokenized-bleu \
--keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
--eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
--eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5 --fixed-validation-seed 7 --ddp-backend=no_c10d \
--share-all-embeddings --decoder-learned-pos --encoder-learned-pos --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
--fp16 --clip-norm 2.0 --max-update 300000 --task translation_glat --criterion glat_loss --arch glat_sd --noise full_mask \
--src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1 --concat-yhat --concat-dropout 0.0 --label-smoothing 0.1 \
--activation-fn gelu --dropout 0.1 --max-tokens 8192 --glat-mode glat
CMLM with DSLP
python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de --save-dir checkpoints --eval-tokenized-bleu \
--keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
--eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
--eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5 --fixed-validation-seed 7 --ddp-backend=no_c10d \
--share-all-embeddings --decoder-learned-pos --encoder-learned-pos --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
--fp16 --clip-norm 2.0 --max-update 300000 --task translation_lev --criterion nat_loss --arch glat_sd --noise full_mask \
--src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1 --concat-yhat --concat-dropout 0.0 --label-smoothing 0.1 \
--activation-fn gelu --dropout 0.1 --max-tokens 8192
Vanilla NAT with DSLP
python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de --save-dir checkpoints --eval-tokenized-bleu \
--keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
--eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
--eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5 --fixed-validation-seed 7 --ddp-backend=no_c10d \
--share-all-embeddings --decoder-learned-pos --encoder-learned-pos --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
--fp16 --clip-norm 2.0 --max-update 300000 --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \
--src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1 --concat-yhat --concat-dropout 0.0 --label-smoothing 0.1 \
--activation-fn gelu --dropout 0.1 --max-tokens 8192
Vanilla NAT with DSLP and Mixed Training
python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de --save-dir checkpoints --eval-tokenized-bleu \
--keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
--eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
--eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5 --fixed-validation-seed 7 --ddp-backend=no_c10d \
--share-all-embeddings --decoder-learned-pos --encoder-learned-pos --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
--fp16 --clip-norm 2.0 --max-update 300000 --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \
--src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1 --concat-yhat --concat-dropout 0.0 --label-smoothing 0.1 \
--activation-fn gelu --dropout 0.1 --max-tokens 8192 --ss-ratio 0.3 --fixed-ss-ratio --masked-loss
CTC with DSLP
python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de --save-dir checkpoints --eval-tokenized-bleu \
--keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
--eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
--eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5 --fixed-validation-seed 7 --ddp-backend=no_c10d \
--share-all-embeddings --decoder-learned-pos --encoder-learned-pos --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
--fp16 --clip-norm 2.0 --max-update 300000 --task translation_lev --criterion nat_loss --arch nat_ctc_sd --noise full_mask \
--src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1 --concat-yhat --concat-dropout 0.0 \
--activation-fn gelu --dropout 0.1 --max-tokens 8192
CTC with DSLP and Mixed Training
python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de --save-dir checkpoints --eval-tokenized-bleu \
--keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
--eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
--eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5 --fixed-validation-seed 7 --ddp-backend=no_c10d \
--share-all-embeddings --decoder-learned-pos --encoder-learned-pos --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
--lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
--fp16 --clip-norm 2.0 --max-update 300000 --task translation_lev --criterion nat_loss --arch nat_ctc_sd_ss --noise full_mask \
--src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1 --concat-yhat --concat-dropout 0.0 \
--activation-fn gelu --dropout 0.1 --max-tokens 8192 --ss-ratio 0.3 --fixed-ss-ratio
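All six training commands above share the same base configuration and differ only in the task, criterion, architecture, and a few extra flags. For quick reference (extracted directly from the commands above):

Model | Key flags (relative to the shared base command) |
---|---|
GLAT w/ DSLP | --task translation_glat --criterion glat_loss --arch glat_sd --glat-mode glat |
CMLM w/ DSLP | --task translation_lev --criterion nat_loss --arch glat_sd |
Vanilla NAT w/ DSLP | --task translation_lev --criterion nat_loss --arch nat_sd |
Vanilla NAT w/ DSLP & Mixed Training | --task translation_lev --criterion nat_loss --arch nat_sd --ss-ratio 0.3 --fixed-ss-ratio --masked-loss |
CTC w/ DSLP | --task translation_lev --criterion nat_loss --arch nat_ctc_sd (no --label-smoothing) |
CTC w/ DSLP & Mixed Training | --task translation_lev --criterion nat_loss --arch nat_ctc_sd_ss --ss-ratio 0.3 --fixed-ss-ratio (no --label-smoothing) |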
Evaluation
fairseq-generate data-bin/wmt14.en-de_kd --path PATH_TO_A_CHECKPOINT \
--gen-subset test --task translation_lev --iter-decode-max-iter 0 \
--iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100
Note: 1) Add --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}' if the model is CTC-based; 2) change the task to translation_glat if the model is GLAT-based.
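For example, evaluating a CTC-based checkpoint (applying note 1 above) would look like the following, which is simply the base command with the extra flags appended:

fairseq-generate data-bin/wmt14.en-de_kd --path PATH_TO_A_CHECKPOINT \
    --gen-subset test --task translation_lev --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100 \
    --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}'

For a GLAT-based checkpoint, instead replace --task translation_lev with --task translation_glat in the base command.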
Output
In addition, we provide the outputs of CTC w/ DSLP, CTC w/ DSLP & Mixed Training, Vanilla NAT w/ DSLP, Vanilla NAT w/ DSLP & Mixed Training, GLAT w/ DSLP, and CMLM w/ DSLP for review purposes.
Model | Reference | Hypothesis |
---|---|---|
CTC w/ DSLP | ref | hyp |
CTC w/ DSLP & Mixed Training | ref | hyp |
Vanilla NAT w/ DSLP | ref | hyp |
Vanilla NAT w/ DSLP & Mixed Training | ref | hyp |
GLAT w/ DSLP | ref | hyp |
CMLM w/ DSLP | ref | hyp |
Note: The outputs are on WMT'14 EN-DE. For each model, the reference file is paired with the corresponding hypothesis file.
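To re-score a reference/hypothesis pair yourself, one option is fairseq's scoring tool. This is only a sketch: it assumes the two files are line-aligned and tokenized, and ref.txt / hyp.txt are placeholder names for the downloaded files.

# Sketch: tokenized BLEU for one model's line-aligned reference/hypothesis pair.
fairseq-score --sys hyp.txt --ref ref.txt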