This repository contains the code for our paper VDA (public in EMNLP2021 main conference)

Related tags

Deep LearningVDA
Overview

Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models

This repository contains the code for our paper VDA (public in EMNLP2021 main conference)

Quick Links

Overview

We propose a general framework Virtual Data Augmentation (VDA) for robustly fine-tuning Pre-trained Language Models for downstream tasks. Our VDA utilizes a masked language model with Gaussian noise to augment virtual examples for improving the robustness, and also adopts regularized training to further guarantee the semantic relevance and diversity.

Train VDA

In the following section, we describe how to train a model with VDA by using our code.

Training

Data

For evaluation of our VDA, we use 6 text classification datasets, i.e. Yelp, IMDB, AGNews, MR, QNLI and MRPC datasets. These datasets can be downloaded from the GoogleDisk

After download the two ziped files, users should unzip the data fold that contains the training, validation and test data of the 6 datasets. While the Robust fold contains the examples for test the robustness.

Training scripts We public our VDA with 4 base models. For single sentence classification tasks, we use text_classifier_xxx.py files. While for sentence pair classification tasks, we use text_pair_classifier_xxx.py:

  • text_classifier.py and text_pair_classifier.py: BERT-base+VDA

  • text_classifier_freelb.py and text_pair_classifier_freelb.py: FreeLB+VDA on BERT-base

  • text_classifier_smart.py and text_pair_classifier_smart.py: SMART+VDA on BERT-base, where we only use the smooth-inducing adversarial regularization.

  • text_classifier_smix.py and text_pair_classifier_smix.py: Smix+VDA on BERT-base, where we remove the adversarial data augmentation for fair comparison

We provide example scripts for both training and test of our VDA on the 6 datasets. In run_train.sh, we provide 6 example for training on the yelp and qnli datasets. This script calls text_classifier_xxx.py for training (xxx refers to the base model). We explain the arguments in following:

  • --dataset: Training file path.
  • --mlm_path: Pre-trained checkpoints to start with. For now we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.)
  • --save_path: Saved fine-tuned checkpoints file.
  • --max_length: Max sequence length. (For Yelp/IMDB/AG, we use 512. While for MR/QNLI/MRPC, we use 256.)
  • --max_epoch: The maximum training epoch number. (In most of datasets and models, we use 10.)
  • --batch_size: The batch size. (We adapt the batch size to the maximum number w.r.t the GPU memory size. Note that too small number may cause model collapse.)
  • --num_label: The number of labels. (For AG, we use 4. While for other, we use 2.)
  • --lr: Learning rate.
  • --num_warmup: The rate of warm-up steps.
  • --variance: The variance of the Gaussian noise.

For results in the paper, we use Nvidia Tesla V100 32G and Nvidia 3090 24G GPUs to train our models. Using different types of devices or different versions of CUDA/other softwares may lead to slightly different performance.

Evaluation

During training, our model file will show the original accuracy on the test set of the 6 datasets, which evaluates the accuracy performance of our model. Our evaluation code for robustness is based on a modified version of BERT-Attack. It outputs Attack Accuracy, Query Numbers and Perturbation Ratio metrics.

Before evaluation, please download the evaluation datasets for Robustness from the GoogleDisk. Then, following the commonly-used settings, users need to download and process consine similarity matrix following TextFooler.

Based on the checkpoint of the fine-tuned models, we use therun_test.sh script for test the robustness on yelp and qnli datasets. It is based on bert_robust.py file. We explain the arguments in following:

  • --data_path: Training file path.
  • --mlm_path: Pre-trained checkpoints to start with. For now we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.)
  • --tgt_path: The fine-tuned checkpoints file.
  • --num_label: The number of labels. (For AG, we use 4. While for other, we use 2.)

which is expected to output the results as:

original accuracy is 0.960000, attack accuracy is 0.533333, query num is 687.680556, perturb rate is 0.177204

Citation

Please cite our paper if you use VDA in your work:

@inproceedings{zhou2021vda,
  author    = {Kun Zhou, Wayne Xin Zhao, Sirui Wang, Fuzheng Zhang, Wei Wu and Ji-Rong Wen},
  title     = {Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models},
  booktitle = {{EMNLP} 2021},
  publisher = {The Association for Computational Linguistics},
}
Owner
RUCAIBox
An enthusiastic group that aims to create beautiful things with AI
RUCAIBox
Can we visualize a large scientific data set with a surrogate model? We're building a GAN for the Earth's Mantle Convection data set to see if we can!

EarthGAN - Earth Mantle Surrogate Modeling Can a surrogate model of the Earth’s Mantle Convection data set be built such that it can be readily run in

Tim 0 Dec 09, 2021
Implementation of the CVPR 2021 paper "Online Multiple Object Tracking with Cross-Task Synergy"

Online Multiple Object Tracking with Cross-Task Synergy This repository is the implementation of the CVPR 2021 paper "Online Multiple Object Tracking

54 Oct 15, 2022
GRF: Learning a General Radiance Field for 3D Representation and Rendering

GRF: Learning a General Radiance Field for 3D Representation and Rendering [Paper] [Video] GRF: Learning a General Radiance Field for 3D Representatio

Alex Trevithick 243 Dec 29, 2022
The full training script for Enformer (Tensorflow Sonnet) on TPU clusters

Enformer TPU training script (wip) The full training script for Enformer (Tensorflow Sonnet) on TPU clusters, in an effort to migrate the model to pyt

Phil Wang 10 Oct 19, 2022
Python code for loading the Aschaffenburg Pose Dataset.

Aschaffenburg Pose Dataset (APD) This repository contains Python code for loading and filtering the Aschaffenburg Pose Dataset. The dataset itself and

1 Nov 26, 2021
Large dataset storage format for Pytorch

H5Record Large dataset ( 100G, = 1T) storage format for Pytorch (wip) Support python 3 pip install h5record Why? Writing large dataset is still a

theblackcat102 43 Oct 22, 2022
9th place solution

AllDataAreExt-Galixir-Kaggle-HPA-2021-Solution Team Members Qishen Ha is Master of Engineering from the University of Tokyo. Machine Learning Engineer

daishu 5 Nov 18, 2021
A code implementation of AC-GC: Activation Compression with Guaranteed Convergence, in NeurIPS 2021.

Code For AC-GC: Lossy Activation Compression with Guaranteed Convergence This code is intended to be used as a supplemental material for submission to

Dave Evans 2 Nov 01, 2022
Use unsupervised and supervised learning to predict stocks

AIAlpha: Multilayer neural network architecture for stock return prediction This project is meant to be an advanced implementation of stacked neural n

Vivek Palaniappan 1.5k Jan 06, 2023
Cross-Modal Contrastive Learning for Text-to-Image Generation

Cross-Modal Contrastive Learning for Text-to-Image Generation This repository hosts the open source JAX implementation of XMC-GAN. Setup instructions

Google Research 94 Nov 12, 2022
🔮 A refreshing functional take on deep learning, compatible with your favorite libraries

Thinc: A refreshing functional take on deep learning, compatible with your favorite libraries From the makers of spaCy, Prodigy and FastAPI Thinc is a

Explosion 2.6k Dec 30, 2022
Data Engineering ZoomCamp

Data Engineering ZoomCamp I'm partaking in a Data Engineering Bootcamp / Zoomcamp and will be tracking my progress here. I can't promise these notes w

Aaron 61 Jan 06, 2023
Dahua Camera and Doorbell Home Assistant Integration

Home Assistant Dahua Integration The Dahua Home Assistant integration allows you to integrate your Dahua cameras and doorbells in Home Assistant. It's

Ronnie 216 Dec 26, 2022
"Learning and Analyzing Generation Order for Undirected Sequence Models" in Findings of EMNLP, 2021

undirected-generation-dev This repo contains the source code of the models described in the following paper "Learning and Analyzing Generation Order f

Yichen Jiang 0 Mar 25, 2022
Vit-ImageClassification - Pytorch ViT for Image classification on the CIFAR10 dataset

Vit-ImageClassification Introduction This project uses ViT to perform image clas

Kaicheng Yang 4 Jun 01, 2022
Bottleneck Transformers for Visual Recognition

Bottleneck Transformers for Visual Recognition Experiments Model Params (M) Acc (%) ResNet50 baseline (ref) 23.5M 93.62 BoTNet-50 18.8M 95.11% BoTNet-

Myeongjun Kim 236 Jan 03, 2023
[NIPS 2021] UOTA: Improving Self-supervised Learning with Automated Unsupervised Outlier Arbitration.

UOTA: Improving Self-supervised Learning with Automated Unsupervised Outlier Arbitration This repository is the official PyTorch implementation of UOT

6 Jun 29, 2022
Multi-Task Deep Neural Networks for Natural Language Understanding

New Release We released Adversarial training for both LM pre-training/finetuning and f-divergence. Large-scale Adversarial training for LMs: ALUM code

Xiaodong 2.1k Dec 30, 2022
QHack—the quantum machine learning hackathon

Official repo for QHack—the quantum machine learning hackathon

Xanadu 72 Dec 21, 2022
Fast Neural Representations for Direct Volume Rendering

Fast Neural Representations for Direct Volume Rendering Sebastian Weiss, Philipp Hermüller, Rüdiger Westermann This repository contains the code and s

Sebastian Weiss 20 Dec 03, 2022