This repository contains the code for the ACL 2020 paper `Dice Loss for Data-imbalanced NLP Tasks`.

Overview

Dice Loss for NLP Tasks

This repository contains code for the paper Dice Loss for Data-imbalanced NLP Tasks, published at ACL 2020.

Setup

  • Install Package Dependencies

The code was tested with Python 3.6.9+ and PyTorch 1.7.1. If you are working on an Ubuntu GPU machine with CUDA 10.1, run the following commands to set up the environment.

$ virtualenv -p /usr/bin/python3.6 venv
$ source venv/bin/activate
$ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt

  • Download BERT Model Checkpoints

Before running the code, you must download the BERT-Base and BERT-Large checkpoints from here and unzip them to a directory $BERT_DIR. Then convert the original TensorFlow checkpoints for BERT to PyTorch saved files by running bash scripts/prepare_ckpt.sh <path-to-unzip-tf-bert-checkpoints>.

Apply Dice-Loss to NLP Tasks

In this repository, we apply dice loss to four NLP tasks, including

  1. machine reading comprehension
  2. paraphrase identification task
  3. named entity recognition
  4. text classification

1. Machine Reading Comprehension

Datasets

We take SQuAD 1.1 as an example. Before training, download a copy of the data from here, then move the SQuAD 1.1 train file train-v1.1.json and dev file dev-v1.1.json to the directory $DATA_DIR.

Train

We choose BERT as the backbone. During training, the task trainer BertForQA automatically evaluates on the dev set every $val_check_interval epoch and saves the dev predictions to $OUTPUT_DIR/predictions_<train-epoch>_<total-train-step>.json and $OUTPUT_DIR/nbest_predictions_<train-epoch>_<total-train-step>.json.
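
Because every evaluation round writes a new predictions file, it can be handy to locate the most recent one before scoring it. The following is a minimal sketch, not part of the repository: the helper name latest_predictions is hypothetical, and it assumes that <train-epoch> and <total-train-step> in the file names are plain integers.

```python
import re
from pathlib import Path

# Assumed pattern, based on the naming scheme described above:
# predictions_<train-epoch>_<total-train-step>.json (both parts integers).
_PRED_RE = re.compile(r"^predictions_(\d+)_(\d+)\.json$")


def latest_predictions(output_dir):
    """Return the predictions file with the highest (epoch, step), or None."""
    best_key, best_path = None, None
    for path in Path(output_dir).iterdir():
        match = _PRED_RE.match(path.name)
        if match is None:
            continue  # skips nbest_predictions_* and unrelated files
        key = (int(match.group(1)), int(match.group(2)))
        if best_key is None or key > best_key:
            best_key, best_path = key, path
    return best_path
```

Comparing the parsed integers rather than the raw file names avoids ordering epoch 10 before epoch 2.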

Run scripts/squad1/bert_<model-scale>_<loss-type>.sh to reproduce our experimental results.
The variable <model-scale> should take the value of [base, large].
The variable <loss-type> should take the value of [bce, focal, dice], which denotes fine-tuning BERT with binary cross-entropy loss, focal loss, and dice loss, respectively.
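
To make the three loss types concrete, here is a minimal pure-Python sketch of their textbook forms for binary labels. It is an illustration only: the function names, the smooth and gamma defaults, and the batch-level dice formulation are assumptions, and the repository's PyTorch implementation (including the self-adjusting dice variant from the paper) may differ.

```python
import math

EPS = 1e-12  # numerical guard against log(0)


def bce_loss(probs, labels):
    """Mean binary cross-entropy over positive-class probabilities."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * math.log(p + EPS) + (1 - y) * math.log(1 - p + EPS))
    return total / len(probs)


def focal_loss(probs, labels, gamma=2.0):
    """Mean focal loss: cross-entropy down-weighted by (1 - p_t) ** gamma."""
    total = 0.0
    for p, y in zip(probs, labels):
        p_t = p if y == 1 else 1 - p  # probability assigned to the true class
        total += -((1 - p_t) ** gamma) * math.log(p_t + EPS)
    return total / len(probs)


def dice_loss(probs, labels, smooth=1.0):
    """1 - smoothed soft Dice coefficient, computed over the whole batch."""
    intersection = sum(p * y for p, y in zip(probs, labels))
    denominator = sum(probs) + sum(labels)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
```

In this sketch, cross-entropy and focal loss average a per-example term, whereas the dice loss is a ratio of overlap to total mass over the batch, so a large number of easy negative examples contributes little to it.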

  • Run bash scripts/squad1/bert_base_focal.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR for focal loss.

  • Run bash scripts/squad1/bert_base_dice.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR for dice loss.

Evaluate

To evaluate a model checkpoint, please run

python3 tasks/squad/evaluate_models.py \
--gpus="1" \
--path_to_model_checkpoint  $OUTPUT_DIR/epoch=2.ckpt \
--eval_batch_size <evaluate-batch-size>

After evaluation, the prediction results predictions_dev.json and nbest_predictions_dev.json can be found in $OUTPUT_DIR.

To evaluate saved predictions, please run

python3 tasks/squad/evaluate_predictions.py <path-to-dev-v1.1.json> <directory-to-prediction-files>
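
For reference, SQuAD 1.1 predictions are conventionally scored with exact match (EM) and token-level F1 after answer normalization. The sketch below follows that standard metric definition; it is not the repository's evaluate_predictions.py, and the function names are chosen here for illustration.

```python
import re
import string
from collections import Counter


def normalize_answer(text):
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(truth))


def token_f1(prediction, truth):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    true_tokens = normalize_answer(truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


def score(prediction, ground_truths):
    """Best EM and F1 over all reference answers for one question."""
    return (max(exact_match(prediction, t) for t in ground_truths),
            max(token_f1(prediction, t) for t in ground_truths))
```

Taking the maximum over the reference answers reflects that each dev question in SQuAD 1.1 can have several acceptable answers.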

2. Paraphrase Identification Task

Datasets

We use MRPC (GLUE Version) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Run bash scripts/prepare_mrpc_data.sh $DATA_DIR to download and process the datasets for the MRPC (GLUE Version) task.

Train

Please run scripts/glue_mrpc/bert_<model-scale>_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. After training, the task trainer evaluates on the test set with the best checkpoint which achieves the highest F1-score on the dev set.
The variable <model-scale> should take the value of [base, large].
The variable <loss-type> should take the value of [focal, dice], which denotes fine-tuning BERT with focal loss and dice loss, respectively.

  • Run bash scripts/glue_mrpc/bert_large_focal.sh for focal loss.

  • Run bash scripts/glue_mrpc/bert_large_dice.sh for dice loss.

The evaluation results on the dev and test set are saved at $OUTPUT_DIR/eval_result_log.txt file.
At most $max_keep_ckpt intermediate model checkpoints are kept.
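
The selection rule described in this section (test with the checkpoint that reaches the highest dev F1) can be sketched as follows. This is a simplified illustration rather than the trainer's actual code: binary_f1 and best_checkpoint are hypothetical helpers, and the mapping from checkpoint name to dev F1 is assumed to be collected by the caller.

```python
def binary_f1(predictions, labels, positive=1):
    """F1-score of the positive class for two equal-length label sequences."""
    pairs = list(zip(predictions, labels))
    tp = sum(1 for p, y in pairs if p == positive and y == positive)
    fp = sum(1 for p, y in pairs if p == positive and y != positive)
    fn = sum(1 for p, y in pairs if p != positive and y == positive)
    if tp == 0:
        return 0.0  # also covers the case with no positive predictions
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_checkpoint(dev_f1_by_ckpt):
    """Return the checkpoint name with the highest dev F1 (first wins on ties)."""
    return max(dev_f1_by_ckpt, key=dev_f1_by_ckpt.get)
```

Here best_checkpoint expects a dictionary such as {checkpoint file name: dev F1}, filled in after each validation round.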

Evaluate

To evaluate a model checkpoint on the test set, please run

bash scripts/glue_mrpc/eval.sh \
$OUTPUT_DIR \
epoch=*.ckpt

3. Named Entity Recognition

For NER, we use the MRC-NER model as the backbone.
Processed datasets and model architecture can be found here.

Train

Please run scripts/<ner-dataset-name>/bert_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. After training, the task trainer evaluates on the test set with the best checkpoint.
The variable <ner-dataset-name> should take the value of [ner_enontonotes5, ner_zhmsra, ner_zhonto4].
The variable <loss-type> should take the value of [focal, dice], which denotes fine-tuning BERT with focal loss and dice loss, respectively.

For Chinese MSRA,

  • Run scripts/ner_zhmsra/bert_focal.sh for focal loss.

  • Run scripts/ner_zhmsra/bert_dice.sh for dice loss.

For Chinese OntoNotes4,

  • Run scripts/ner_zhonto4/bert_focal.sh for focal loss.

  • Run scripts/ner_zhonto4/bert_dice.sh for dice loss.

For English OntoNotes5,

  • Run scripts/ner_enontonotes5/bert_focal.sh. After training, you will get 91.12 Span-F1 on the test set.

  • Run scripts/ner_enontonotes5/bert_dice.sh. After training, you will get 92.01 Span-F1 on the test set.
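
Span-F1 counts a predicted entity as correct only when its boundaries and label both match a gold entity. A minimal sketch of that metric is shown below; the span_f1 helper and the (sentence_id, start, end, label) tuple layout are assumptions made for illustration, not the repository's evaluation code.

```python
def span_f1(predicted_spans, gold_spans):
    """Micro precision, recall, and F1 over (sentence_id, start, end, label) spans."""
    predicted, gold = set(predicted_spans), set(gold_spans)
    true_positives = len(predicted & gold)  # exact boundary and label match
    precision = true_positives / len(predicted) if predicted else 0.0
    recall = true_positives / len(gold) if gold else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
```

A span with a correct label but a boundary that is off by one token counts as both a false positive and a false negative under this definition.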

Evaluate

To evaluate a model checkpoint, please run

CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \
--gpus="1" \
--path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt

4. Text Classification

Datasets

We use TNews (Chinese Text Classification) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Train

We choose BERT as the backbone.
Please run scripts/tnews/bert_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. The variable <loss-type> should take the value of [focal, dice], which denotes fine-tuning BERT with focal loss and dice loss, respectively.

  • Run bash scripts/tnews/bert_focal.sh for focal loss.

  • Run bash scripts/tnews/bert_dice.sh for dice loss.

At most $max_keep_ckpt intermediate model checkpoints are kept.

Citation

If you find this repository useful, please cite the following:

@article{li2019dice,
  title={Dice loss for data-imbalanced NLP tasks},
  author={Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Liang, Junjun and Wu, Fei and Li, Jiwei},
  journal={arXiv preprint arXiv:1911.02855},
  year={2019}
}

Contact

xiaoyalixy AT gmail.com OR xiaoya_li AT shannonai.com

Any discussions, suggestions and questions are welcome!
