codes for "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation" (long paper of EMNLP-2022)

Overview

Scheduled Sampling Based on Decoding Steps for Neural Machine Translation (EMNLP-2021 main conference)

Contents

Overview

We propose to conduct scheduled sampling based on decoding steps instead of the original training steps. We observe that our proposal can more realistically simulate the distribution of real translation errors, thus better bridging the gap between training and inference. The paper has been accepted to the main conference of EMNLP-2021.

Background

fastText

We conduct scheduled sampling for the Transformer with a two-pass decoder. An example of pseudo-code is as follows:

# first-pass: the same as the standard Transformer decoder
first_decoder_outputs = decoder(first_decoder_inputs)

# sampling tokens between model predicitions and ground-truth tokens
second_decoder_inputs = sampling_function(first_decoder_outputs, first_decoder_inputs)

# second-pass: computing the decoder again with the above sampled tokens
second_decoder_outputs = decoder(second_decoder_inputs)

Quick to Use

Our approaches are suitable for most autoregressive-based tasks. Please try the following pseudo-codes when conducting scheduled sampling:

import torch

def sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths)
    '''
    conduct scheduled sampling based on the index of decoded tokens 
    param first_decoder_outputs: [batch_size, seq_len, hidden_size], model prediections 
    param first_decoder_inputs: [batch_size, seq_len, hidden_size], ground-truth target tokens
    param max_seq_len: scalar, the max lengh of target sequence
    param tgt_lengths: [batch_size], the lenghs of target sequences in a mini-batch
    '''

    # indexs of decoding steps
    t = torch.range(0, max_seq_len-1)

    # differenct sampling strategy based on decoding steps
    if sampling_strategy == "exponential":
        threshold_table = exp_radix ** t  
    elif sampling_strategy == "sigmoid":
        threshold_table = sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k ))
    elif sampling_strategy == "linear":        
        threshold_table = torch.max(epsilon, 1 - t / max_seq_len)
    else:
        ValuraiseeError("Unknown sampling_strategy %s" % sampling_strategy)

    # convert threshold_table to [batch_size, seq_len]
    threshold_table = threshold_table.unsqueeze_(0).repeat(max_seq_len, 1).tril()
    thresholds = threshold_table[tgt_lengths].view(-1, max_seq_len)
    thresholds = current_thresholds[:, :seq_len]

    # conduct sampling based on the above thresholds
    random_select_seed = torch.rand([batch_size, seq_len]) 
    second_decoder_inputs = torch.where(random_select_seed < thresholds, first_decoder_inputs, first_decoder_outputs)

    return second_decoder_inputs
    

Further Usage

Error accumulation is a common phenomenon in NLP tasks. Whenever you want to simulate the accumulation of errors, our method may come in handy. For examples:

# sampling tokens between noisy target tokens and ground-truth tokens
decoder_inputs = sampling_function(noisy_decoder_inputs, golden_decoder_inputs, max_seq_len, tgt_lengths)

# computing the decoder with the above sampled tokens
decoder_outputs = decoder(decoder_inputs)
# sampling utterences from model predictions and ground-truth utterences
contexts = sampling_function(predicted_utterences, golden_utterences, max_turns, current_turns)

model_predictions = dialogue_model(contexts, target_inputs)

Experiments

We provide scripts to reproduce the results in this paper(NMT and text summarization)

Citation

Please cite this paper if you find this repo useful.

@inproceedings{liu_ss_decoding_2021,
    title = "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation",
    author = "Liu, Yijin  and
      Meng, Fandong  and
      Chen, Yufeng  and
      Xu, Jinan  and
      Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    year = "2021",
    address = "Online"
}

Contact

Please feel free to contact us ([email protected]) for any further questions.

Owner
Adaxry
Fast learner, eagle for new knowledge and deeper understanding
Adaxry
PyTorch Kafka Dataset: A definition of a dataset to get training data from Kafka.

PyTorch Kafka Dataset: A definition of a dataset to get training data from Kafka.

ERTIS Research Group 7 Aug 01, 2022
[CVPR2021 Oral] UP-DETR: Unsupervised Pre-training for Object Detection with Transformers

UP-DETR: Unsupervised Pre-training for Object Detection with Transformers This is the official PyTorch implementation and models for UP-DETR paper: @a

dddzg 430 Dec 23, 2022
Multi-Scale Geometric Consistency Guided Multi-View Stereo

ACMM [News] The code for ACMH is released!!! [News] The code for ACMP is released!!! About ACMM is a multi-scale geometric consistency guided multi-vi

Qingshan Xu 118 Jan 04, 2023
Learning Efficient Online 3D Bin Packing on Packing Configuration Trees

Learning Efficient Online 3D Bin Packing on Packing Configuration Trees This repository is being continuously updated, please stay tuned! Any code con

86 Dec 28, 2022
Fully Automatic Page Turning on Real Scores

Fully Automatic Page Turning on Real Scores This repository contains the corresponding code for our extended abstract Henkel F., Schwaiger S. and Widm

Florian Henkel 7 Jan 02, 2022
使用深度学习框架提取视频硬字幕;docker容器免安装深度学习库,使用本地api接口使得界面和后端识别分离;

extract-video-subtittle 使用深度学习框架提取视频硬字幕; 本地识别无需联网; CPU识别速度可观; 容器提供API接口; 运行环境 本项目运行环境非常好搭建,我做好了docker容器免安装各种深度学习包; 提供windows界面操作; 容器为CPU版本; 视频演示 https

歌者 16 Aug 06, 2022
Research Artifact of USENIX Security 2022 Paper: Automated Side Channel Analysis of Media Software with Manifold Learning

Manifold-SCA Research Artifact of USENIX Security 2022 Paper: Automated Side Channel Analysis of Media Software with Manifold Learning The repo is org

Yuanyuan Yuan 172 Dec 29, 2022
Hybrid Neural Fusion for Full-frame Video Stabilization

FuSta: Hybrid Neural Fusion for Full-frame Video Stabilization Project Page | Video | Paper | Google Colab Setup Setup environment for [Yu and Ramamoo

Yu-Lun Liu 430 Jan 04, 2023
Taking A Closer Look at Domain Shift: Category-level Adversaries for Semantics Consistent Domain Adaptation

Taking A Closer Look at Domain Shift: Category-level Adversaries for Semantics Consistent Domain Adaptation (CVPR2019) This is a pytorch implementatio

Yawei Luo 280 Jan 01, 2023
PyTorch implementation of ENet

PyTorch-ENet PyTorch (v1.1.0) implementation of ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation, ported from the lua-torc

David Silva 333 Dec 29, 2022
One Million Scenes for Autonomous Driving

ONCE Benchmark This is a reproduced benchmark for 3D object detection on the ONCE (One Million Scenes) dataset. The code is mainly based on OpenPCDet.

148 Dec 28, 2022
Convert ONNX model graph to Keras model format.

Convert ONNX model graph to Keras model format.

Grigory Malivenko 175 Dec 28, 2022
A Player for Kanye West's Stem Player. Sort of an emulator.

Stem Player Player Stem Player Player Usage Download the latest release here Optional: install ffmpeg, instructions here NOTE: DOES NOT ENABLE DOWNLOA

119 Dec 28, 2022
Mitsuba 2: A Retargetable Forward and Inverse Renderer

Mitsuba Renderer 2 Documentation Mitsuba 2 is a research-oriented rendering system written in portable C++17. It consists of a small set of core libra

Mitsuba Physically Based Renderer 2k Jan 07, 2023
Ἀνατομή is a PyTorch library to analyze representation of neural networks

Ἀνατομή is a PyTorch library to analyze representation of neural networks

Ryuichiro Hataya 50 Dec 05, 2022
Simple Text-Generator with OpenAI gpt-2 Pytorch Implementation

GPT2-Pytorch with Text-Generator Better Language Models and Their Implications Our model, called GPT-2 (a successor to GPT), was trained simply to pre

Tae-Hwan Jung 775 Jan 08, 2023
Normal Learning in Videos with Attention Prototype Network

Codes_APN Official codes of CVPR21 paper: Normal Learning in Videos with Attention Prototype Network (https://arxiv.org/abs/2108.11055) Overview of ou

11 Dec 13, 2022
A lightweight deep network for fast and accurate optical flow estimation.

FastFlowNet: A Lightweight Network for Fast Optical Flow Estimation The official PyTorch implementation of FastFlowNet (ICRA 2021). Authors: Lingtong

Tone 161 Jan 03, 2023
Faster RCNN pytorch windows

Faster-RCNN-pytorch-windows Faster RCNN implementation with pytorch for windows Open cmd, compile this comands: cd lib python setup.py build develop T

Hwa-Rang Kim 1 Nov 11, 2022
Learning High-Speed Flight in the Wild

Learning High-Speed Flight in the Wild This repo contains the code associated to the paper Learning Agile Flight in the Wild. For more information, pl

Robotics and Perception Group 391 Dec 29, 2022