Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) - PyTorch Implementation

Overview

PyTorch implementation for the Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) paper.

Method Description

Distilled Sentence Embedding (DSE) distills knowledge from a finetuned state-of-the-art transformer model (BERT) to create high quality sentence embeddings. For a complete description, as well as implementation details and hyperparameters, please refer to the paper.
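The distillation training signal can be illustrated with a minimal logit-matching loss. The Siamese student produces logits for each sentence pair and is penalized for deviating from the finetuned BERT teacher's logits. This is a sketch in plain Python for readability (the repository works on PyTorch tensors). The use of mean squared error and the function names `distillation_mse` and `batch_distillation_loss` are assumptions for illustration; the paper specifies the exact objective DSE uses.

```python
# Minimal sketch of a logit-distillation objective.
# ASSUMPTION: mean squared error between student and teacher logits is used
# here for illustration; see the paper for the loss DSE actually optimizes.
from typing import Sequence


def distillation_mse(student_logits: Sequence[float],
                     teacher_logits: Sequence[float]) -> float:
    """Mean squared error between one student and one teacher logit vector."""
    if len(student_logits) != len(teacher_logits):
        raise ValueError("logit vectors must have the same length")
    squared = [(s - t) ** 2 for s, t in zip(student_logits, teacher_logits)]
    return sum(squared) / len(squared)


def batch_distillation_loss(student_batch: Sequence[Sequence[float]],
                            teacher_batch: Sequence[Sequence[float]]) -> float:
    """Average the per-pair distillation loss over a batch of sentence pairs."""
    losses = [distillation_mse(s, t) for s, t in zip(student_batch, teacher_batch)]
    return sum(losses) / len(losses)


# One sentence pair: the student is 0.5 away from the teacher on both logits.
print(batch_distillation_loss([[1.5, -0.5]], [[2.0, -1.0]]))  # → 0.25
```

Matching logits rather than hard labels lets the student learn from the teacher's full output, including how confident it is on each pair.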

Usage

Follow the instructions below to run the training procedure of the Distilled Sentence Embedding (DSE) method. Each Python script below can be run with the -h flag to print more information about its arguments.

1. Install Requirements

Tested with Python 3.6+.

pip install -r requirements.txt

2. Download GLUE Datasets

Run the download_glue_data.py script to download the GLUE datasets.

python download_glue_data.py

3. Finetune BERT on a Specific Task

Finetune a standard BERT model on a specific task (e.g., MRPC or MNLI). Below is an example for the MRPC dataset.

python finetune_bert.py \
--bert_model bert-large-uncased-whole-word-masking \
--task_name mrpc \
--do_train \
--do_eval \
--do_lower_case \
--data_dir glue_data/MRPC \
--max_seq_length 128 \
--train_batch_size 32 \
--gradient_accumulation_steps 2 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir outputs/large_uncased_finetuned_mrpc \
--overwrite_output_dir \
--no_parallel
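The command above combines --train_batch_size 32 with --gradient_accumulation_steps 2, which keeps memory use manageable for BERT-large. The sketch below shows the resulting schedule, assuming (this is not confirmed by the repository) that finetune_bert.py follows the convention of the original pytorch-pretrained-BERT example scripts: --train_batch_size is the effective batch size and is split across accumulation steps. The helper name `finetune_schedule` is hypothetical.

```python
# Sketch: per-step batch, effective batch and optimizer-step count for the
# fine-tuning command above.
# ASSUMPTION: finetune_bert.py follows the original pytorch-pretrained-BERT
# convention, where --train_batch_size is divided by the accumulation steps.


def finetune_schedule(num_examples: int, train_batch_size: int,
                      grad_accum_steps: int, num_epochs: int) -> dict:
    # Examples processed in each forward/backward pass.
    per_step_batch = train_batch_size // grad_accum_steps
    # Gradients are accumulated over grad_accum_steps passes per optimizer update.
    optimizer_steps = int(num_examples / per_step_batch / grad_accum_steps) * num_epochs
    return {
        "per_step_batch": per_step_batch,
        "effective_batch": per_step_batch * grad_accum_steps,
        "optimizer_steps": optimizer_steps,
    }


# The MRPC training set has 3,668 sentence pairs.
print(finetune_schedule(3668, 32, 2, 3))
# → {'per_step_batch': 16, 'effective_batch': 32, 'optimizer_steps': 342}
```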

Note: To use the AllNLI dataset (a combination of the MNLI and SNLI datasets) with our code, create a new folder in the directory that contains the downloaded GLUE datasets and copy the MNLI and SNLI datasets into it.
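A minimal sketch of the AllNLI folder setup is shown below. The folder name "AllNLI", the nested layout (glue_data/AllNLI/MNLI and glue_data/AllNLI/SNLI) and the helper name `build_allnli` are assumptions for illustration; verify the layout the repository's data loaders expect before relying on it.

```python
# Sketch: assemble an AllNLI folder from the downloaded MNLI and SNLI data.
# ASSUMPTIONS: the folder name "AllNLI" and the nested layout are illustrative
# only; check them against the repository's data loaders.
import shutil
from pathlib import Path


def build_allnli(glue_dir: str, folder_name: str = "AllNLI") -> Path:
    glue = Path(glue_dir)
    allnli = glue / folder_name
    allnli.mkdir(exist_ok=True)
    for task in ("MNLI", "SNLI"):
        src = glue / task
        if not src.is_dir():
            raise FileNotFoundError(f"expected downloaded dataset at {src}")
        dst = allnli / task
        if not dst.exists():  # skip datasets that were already copied
            shutil.copytree(src, dst)
    return allnli
```

With the datasets downloaded to glue_data, calling `build_allnli("glue_data")` creates the folder; the function is safe to re-run because already-copied datasets are skipped.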

4. Create Logits for Distillation from the Finetuned BERT

Execute the following command to create the logits which will be used for the distillation training objective. Note that the bert_checkpoint_dir parameter has to match the output_dir from the previous command.

python run_distillation_logits_creator.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--do_lower_case \
--train_features_path glue_data/MRPC/train_bert-large-uncased-whole-word-masking_128_mrpc \
--bert_checkpoint_dir outputs/large_uncased_finetuned_mrpc
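The --train_features_path value in the example appears to follow a naming pattern built from the fine-tuning arguments: {data_dir}/train_{bert_model}_{max_seq_length}_{task_name}. This pattern is inferred from the example (it matches how the original pytorch-pretrained-BERT scripts named their cached training features), so treat it as an assumption. The hypothetical helper below builds the path when adapting the command to another task.

```python
# Sketch: derive --train_features_path from the fine-tuning arguments.
# ASSUMPTION: the naming pattern is inferred from the MRPC example and has not
# been confirmed against the repository code.
import os


def train_features_path(data_dir: str, bert_model: str,
                        max_seq_length: int, task_name: str) -> str:
    # Keep only the last path component, in case bert_model is a local directory.
    model_name = [part for part in bert_model.split("/") if part][-1]
    return os.path.join(data_dir, f"train_{model_name}_{max_seq_length}_{task_name}")


print(train_features_path("glue_data/MRPC", "bert-large-uncased-whole-word-masking", 128, "mrpc"))
# → glue_data/MRPC/train_bert-large-uncased-whole-word-masking_128_mrpc (on POSIX systems)
```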

5. Train the DSE Model using the Finetuned BERT Logits

Train the DSE model using the extracted logits. Notice that the distillation_logits_path parameter needs to be changed according to the task.

python dse_train_runner.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--distillation_logits_path outputs/logits/large_uncased_finetuned_mrpc_logits.pt \
--do_lower_case \
--file_log \
--epochs 8 \
--store_checkpoints \
--fc_dims 512 1

Important Notes:

  • To store model checkpoints, make sure the --store_checkpoints flag is passed, as shown above.
  • The fc_dims parameter accepts a space-separated list of integers giving the layer dimensions of the fully connected classifier placed on top of the features extracted by the Siamese DSE model. The last value is the output dimension and must be set according to the task: it is 1 in the MRPC example above, while for MNLI, a 3-class classification task, fc_dims should be 512 3.
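The sketch below shows how a --fc_dims list could map to classifier layers on top of the Siamese features. It assumes the two sentence embeddings u and v are combined InferSent-style as [u, v, |u - v|, u * v]; this combination, the embedding size of 1024 (the BERT-large hidden size) and the helper names are assumptions for illustration, so consult the paper and the DSESiameseClassifier code for the actual architecture.

```python
# Sketch: how a space-separated --fc_dims list maps to classifier layers.
# ASSUMPTIONS: pair features are [u, v, |u - v|, u * v]; the embedding size
# is illustrative only.
from typing import List, Sequence, Tuple


def combine_pair_features(u: Sequence[float], v: Sequence[float]) -> List[float]:
    """Concatenate [u, v, |u - v|, u * v] for one sentence pair."""
    diff = [abs(a - b) for a, b in zip(u, v)]
    prod = [a * b for a, b in zip(u, v)]
    return list(u) + list(v) + diff + prod


def classifier_layer_shapes(embedding_dim: int,
                            fc_dims: Sequence[int]) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each fully connected layer."""
    shapes = []
    in_dim = 4 * embedding_dim  # size of the combined pair features
    for out_dim in fc_dims:
        shapes.append((in_dim, out_dim))
        in_dim = out_dim  # each layer feeds the next
    return shapes


print(classifier_layer_shapes(1024, [512, 1]))  # MRPC: [(4096, 512), (512, 1)]
print(classifier_layer_shapes(1024, [512, 3]))  # MNLI: [(4096, 512), (512, 3)]
```

Only the last entry of fc_dims changes between tasks, which is why it must match the number of outputs the task needs.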

6. Loading the Trained DSE Model

During training, checkpoints of the Trainer object (which contains the model) are saved. You can load a checkpoint, extract the model state dictionary from it, and load that state into a newly created DSESiameseClassifier model. The load_dse_checkpoint_example.py script shows how to do this.

To load the model trained with the example commands above, run:

python load_dse_checkpoint_example.py \
--task_name mrpc \
--trainer_checkpoint <path_to_saved_checkpoint> \
--do_lower_case \
--fc_dims 512 1
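For loading in your own code, a checkpoint can be read with `torch.load(path, map_location="cpu")` and the extracted weights passed to the model's `load_state_dict`. The helper below sketches the extraction step on the loaded dictionary. The key "model_state_dict" is hypothetical (load_dse_checkpoint_example.py shows the key the repository actually uses), and the "module." prefix stripping only matters if the weights were saved from a model wrapped in nn.DataParallel.

```python
# Sketch: pull a model state dict out of a loaded Trainer checkpoint.
# ASSUMPTIONS: the checkpoint is a dict (as returned by torch.load) and the
# key "model_state_dict" is hypothetical; check load_dse_checkpoint_example.py.
from collections import OrderedDict
from typing import Any, Mapping


def extract_model_state(checkpoint: Mapping[str, Any],
                        key: str = "model_state_dict") -> "OrderedDict[str, Any]":
    if key not in checkpoint:
        raise KeyError(f"checkpoint has no {key!r} entry; available keys: {sorted(checkpoint)}")
    state = checkpoint[key]
    # Weights saved from an nn.DataParallel-wrapped model carry a "module."
    # prefix that must be removed before loading into an unwrapped model.
    prefix = "module."
    return OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, value)
        for name, value in state.items()
    )
```

The result can then be loaded with `model.load_state_dict(extract_model_state(checkpoint))`, where `model` is a DSESiameseClassifier built with the same fc_dims used for training.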

Citation

If you find this code useful, please cite the following paper:

@inproceedings{barkan2020scalable,
  title={Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding},
  author={Barkan, Oren and Razin, Noam and Malkiel, Itzik and Katz, Ori and Caciularu, Avi and Koenigstein, Noam},
  booktitle={AAAI Conference on Artificial Intelligence (AAAI)},
  year={2020}
}