Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models

This repository contains the code for our paper VDA (published in the EMNLP 2021 main conference).

Overview

We propose Virtual Data Augmentation (VDA), a general framework for robustly fine-tuning pre-trained language models on downstream tasks. VDA uses a masked language model with Gaussian noise to generate virtual augmented examples that improve robustness, and adopts regularized training to further ensure semantic relevance and diversity.
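The core augmentation idea can be illustrated with a small, dependency-free sketch. It rests on simplifying assumptions: the masked language model is replaced by a hypothetical toy scorer (`toy_mlm_logits`), Gaussian noise is added to a token's input embedding before scoring, and the virtual token embedding is the probability-weighted average of the vocabulary embeddings. The KL term is one common consistency regularizer, included only to suggest what regularized training can look like; it is not claimed to be the exact loss used in this repository.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def add_gaussian_noise(vec, variance, rng):
    """Add zero-mean Gaussian noise with the given variance (std = sqrt(variance))."""
    std = math.sqrt(variance)
    return [x + rng.gauss(0.0, std) for x in vec]

def toy_mlm_logits(hidden, vocab_embeddings):
    """HYPOTHETICAL stand-in for a masked-LM head: dot-product score between
    the (noisy) input vector and every vocabulary embedding."""
    return [sum(h * e for h, e in zip(hidden, emb)) for emb in vocab_embeddings]

def virtual_embedding(token_emb, vocab_embeddings, variance, rng):
    """Return (virtual_emb, probs): the expected vocabulary embedding under the
    substitution distribution predicted from a noise-perturbed input."""
    noisy = add_gaussian_noise(token_emb, variance, rng)
    probs = softmax(toy_mlm_logits(noisy, vocab_embeddings))
    dim = len(token_emb)
    virtual = [sum(p * emb[d] for p, emb in zip(probs, vocab_embeddings))
               for d in range(dim)]
    return virtual, probs

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions; assumed here as a consistency
    regularizer between predictions on original and virtual inputs."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

if __name__ == "__main__":
    rng = random.Random(0)
    # Toy 3-word vocabulary with 2-d embeddings (illustrative values only).
    vocab = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
    # `variance` plays the role of the --variance training argument.
    virt, probs = virtual_embedding(vocab[0], vocab, variance=0.01, rng=rng)
    print("substitution distribution:", [round(p, 3) for p in probs])
    print("virtual embedding:", [round(v, 3) for v in virt])
```

Because the virtual embedding is a convex combination of real embeddings, it stays inside the region spanned by the vocabulary while still differing from the original token, which is the intuition behind "virtual" examples.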

Train VDA

In the following section, we describe how to train a model with VDA using our code.

Training

Data

To evaluate VDA, we use six text classification datasets: Yelp, IMDB, AGNews, MR, QNLI, and MRPC. These datasets can be downloaded from Google Drive.

After downloading the two zipped files, unzip them. The data folder contains the training, validation, and test sets of the six datasets, while the Robust folder contains the examples used to test robustness.
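The unzip step can also be done from Python with the standard library. This is a minimal sketch; the archive names in the usage comment are placeholders, since the actual file names on Google Drive are not stated here.

```python
import zipfile
from pathlib import Path

def extract_archives(archives, dest="."):
    """Extract each zip archive into `dest` and return the sorted top-level
    entries it contained (e.g. the data and Robust folders)."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    top_level = set()
    for archive in archives:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(dest)
            # Record the first path component of every member.
            top_level.update(name.split("/")[0] for name in zf.namelist())
    return sorted(top_level)

# Usage (placeholder archive names; substitute the files you downloaded):
# extract_archives(["data.zip", "Robust.zip"], dest=".")
```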

Training scripts

We release VDA with four base models. Single-sentence classification tasks use the text_classifier_xxx.py files, while sentence-pair classification tasks use the text_pair_classifier_xxx.py files:

  • text_classifier.py and text_pair_classifier.py: BERT-base+VDA

  • text_classifier_freelb.py and text_pair_classifier_freelb.py: FreeLB+VDA on BERT-base

  • text_classifier_smart.py and text_pair_classifier_smart.py: SMART+VDA on BERT-base, where we only use the smooth-inducing adversarial regularization.

  • text_classifier_smix.py and text_pair_classifier_smix.py: Smix+VDA on BERT-base, where we remove the adversarial data augmentation for a fair comparison.
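The naming convention above can be captured in a small lookup. `training_script` and the lowercase dataset keys are hypothetical (not part of the repository); only the script file names come from the list above.

```python
# Base-model suffixes from the list above; plain BERT uses the un-suffixed scripts.
SUFFIXES = {"bert": "", "freelb": "_freelb", "smart": "_smart", "smix": "_smix"}

# Datasets used in this README, grouped by task type (keys are hypothetical).
PAIR_DATASETS = {"qnli", "mrpc"}
SINGLE_DATASETS = {"yelp", "imdb", "ag", "agnews", "mr"}

def training_script(dataset: str, base_model: str = "bert") -> str:
    """Return the training script name for a dataset and base model."""
    dataset, base_model = dataset.lower(), base_model.lower()
    if base_model not in SUFFIXES:
        raise ValueError(f"unknown base model: {base_model}")
    if dataset in PAIR_DATASETS:
        prefix = "text_pair_classifier"      # sentence-pair classification
    elif dataset in SINGLE_DATASETS:
        prefix = "text_classifier"           # single-sentence classification
    else:
        raise ValueError(f"unknown dataset: {dataset}")
    return f"{prefix}{SUFFIXES[base_model]}.py"
```

For example, `training_script("qnli", "freelb")` gives `text_pair_classifier_freelb.py`, and `training_script("yelp")` gives `text_classifier.py`.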

We provide example scripts for both training and testing VDA on the six datasets. run_train.sh contains six examples of training on the Yelp and QNLI datasets. It calls text_classifier_xxx.py (or text_pair_classifier_xxx.py for sentence-pair tasks) for training, where xxx refers to the base model. The arguments are as follows:

  • --dataset: Training file path.
  • --mlm_path: Pre-trained checkpoints to start with. For now we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.)
  • --save_path: Path where the fine-tuned checkpoint is saved.
  • --max_length: Maximum sequence length. (We use 512 for Yelp/IMDB/AG and 256 for MR/QNLI/MRPC.)
  • --max_epoch: Maximum number of training epochs. (We use 10 for most datasets and models.)
  • --batch_size: Batch size. (We set it to the largest value that fits in GPU memory. Note that a batch size that is too small may cause the model to collapse.)
  • --num_label: Number of labels. (4 for AG; 2 for the other datasets.)
  • --lr: Learning rate.
  • --num_warmup: The warm-up ratio, i.e. the proportion of training steps used for learning-rate warm-up.
  • --variance: The variance of the Gaussian noise.
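The per-dataset settings above can be collected into a helper that assembles the documented flags. `train_args` is a hypothetical helper; the README does not give values for --lr, --batch_size, --num_warmup, or --variance, so the numbers in the usage lines are placeholders rather than the paper's settings, and the data paths are made up.

```python
import shlex

# Settings reported in this README (dataset keys are hypothetical).
MAX_LENGTH = {"yelp": 512, "imdb": 512, "ag": 512, "agnews": 512,
              "mr": 256, "qnli": 256, "mrpc": 256}
NUM_LABEL = {"ag": 4, "agnews": 4}  # all other datasets are binary

def train_args(dataset, data_file, save_path, *, batch_size, lr, num_warmup,
               variance, mlm_path="bert-base-uncased", max_epoch=10):
    """Return the documented training flags as a flat list of strings."""
    key = dataset.lower()
    flags = {
        "--dataset": data_file,
        "--mlm_path": mlm_path,
        "--save_path": save_path,
        "--max_length": MAX_LENGTH[key],
        "--max_epoch": max_epoch,
        "--batch_size": batch_size,
        "--num_label": NUM_LABEL.get(key, 2),
        "--lr": lr,
        "--num_warmup": num_warmup,
        "--variance": variance,
    }
    return [tok for flag, value in flags.items() for tok in (flag, str(value))]

if __name__ == "__main__":
    # Placeholder paths and hyperparameters, for illustration only.
    cmd = ["python", "text_classifier.py"] + train_args(
        "yelp", "data/yelp", "checkpoints/yelp_vda",
        batch_size=16, lr=2e-5, num_warmup=0.1, variance=0.01)
    print(shlex.join(cmd))
```

Keeping the dataset-dependent values (sequence length, label count) in lookup tables means only the tuned hyperparameters need to be supplied per run.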

For the results in the paper, we trained our models on NVIDIA Tesla V100 (32 GB) and NVIDIA RTX 3090 (24 GB) GPUs. Different devices or different versions of CUDA and other software may lead to slightly different performance.

Evaluation

During training, the training script reports the original accuracy on the test set of each of the six datasets, which measures the model's standard (clean) performance. Our robustness evaluation code is based on a modified version of BERT-Attack and reports three metrics: attack accuracy, number of queries, and perturbation ratio.
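The reported metrics can be sketched as follows. The record format (`AttackRecord`) is hypothetical, and averaging conventions differ across attack implementations. Here attack accuracy is the fraction of all test examples that are still classified correctly after the attack, while query number and perturbation ratio are averaged over the examples that were originally correct (the ones actually attacked); the modified BERT-Attack code may use a different convention.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class AttackRecord:
    """Outcome of attacking one test example (hypothetical record format)."""
    orig_correct: bool  # clean prediction was correct
    adv_correct: bool   # prediction is still correct after the attack
    num_queries: int    # model queries spent by the attacker
    num_changed: int    # words the attacker substituted
    num_words: int      # words in the input

def robustness_metrics(records: List[AttackRecord]) -> Dict[str, float]:
    """Compute the reported numbers under the assumed conventions above."""
    n = len(records)
    # Assumption: only originally correct examples are attacked and averaged.
    attacked = [r for r in records if r.orig_correct]
    m = max(len(attacked), 1)
    return {
        "original accuracy": sum(r.orig_correct for r in records) / n,
        # Accuracy under attack: counts only examples correct before and after.
        "attack accuracy": sum(r.orig_correct and r.adv_correct for r in records) / n,
        "query num": sum(r.num_queries for r in attacked) / m,
        "perturb rate": sum(r.num_changed / r.num_words for r in attacked) / m,
    }
```

A lower gap between original and attack accuracy, together with more queries and a higher perturbation ratio needed by the attacker, indicates a more robust model.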

Before evaluation, please download the robustness evaluation datasets from Google Drive. Then, following common practice, download and process the cosine similarity matrix as described in TextFooler.
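TextFooler filters candidate substitutions using a precomputed cosine similarity matrix over word embeddings. The sketch below only illustrates what that matrix contains, computed for a tiny in-memory table; for the real vocabulary, follow TextFooler's own instructions and scripts. The `nearest_words` helper, its `k`, and its threshold are illustrative assumptions, not TextFooler's settings.

```python
import math
from typing import List

def cosine_similarity_matrix(embeddings: List[List[float]]) -> List[List[float]]:
    """Pairwise cosine similarities: sim[i][j] = <e_i, e_j> / (|e_i| * |e_j|)."""
    norms = [math.sqrt(sum(x * x for x in e)) for e in embeddings]
    # Normalize each vector; zero vectors stay zero (similarity 0 with everything).
    unit = [[x / n for x in e] if n > 0 else [0.0] * len(e)
            for e, n in zip(embeddings, norms)]
    return [[sum(a * b for a, b in zip(u, v)) for v in unit] for u in unit]

def nearest_words(word: str, vocab: List[str], sim: List[List[float]],
                  k: int = 2, threshold: float = 0.5) -> List[str]:
    """Up to k other words whose similarity to `word` is at least `threshold`
    (illustrative synonym-candidate filter; k and threshold are arbitrary)."""
    i = vocab.index(word)
    others = sorted((j for j in range(len(vocab)) if j != i),
                    key=lambda j: sim[i][j], reverse=True)
    return [vocab[j] for j in others[:k] if sim[i][j] >= threshold]
```

Precomputing the full matrix once trades memory for speed: during an attack, each candidate check becomes a table lookup instead of a vector computation.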

Given the checkpoint of a fine-tuned model, the run_test.sh script tests robustness on the Yelp and QNLI datasets. It is based on bert_robust.py, which takes the following arguments:

  • --data_path: Path to the evaluation data (from the Robust folder).
  • --mlm_path: Pre-trained checkpoints to start with. For now we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.)
  • --tgt_path: The fine-tuned checkpoints file.
  • --num_label: Number of labels. (4 for AG; 2 for the other datasets.)

The script is expected to print results in the following format:

original accuracy is 0.960000, attack accuracy is 0.533333, query num is 687.680556, perturb rate is 0.177204
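When comparing several runs, it can help to turn this summary line into numbers. A minimal parser, assuming the output format shown above stays stable (`parse_robust_output` is a hypothetical helper, not part of the repository):

```python
import re

# Matches "<metric name> is <number>" pairs, e.g. "attack accuracy is 0.533333".
_METRIC = re.compile(r"([a-z ]+?) is ([0-9.]+)")

def parse_robust_output(line: str) -> dict:
    """Turn the printed summary line into {metric name: float}."""
    return {name.strip(): float(value) for name, value in _METRIC.findall(line)}
```

Applied to the line above, it yields a dictionary keyed by `original accuracy`, `attack accuracy`, `query num`, and `perturb rate`.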

Citation

Please cite our paper if you use VDA in your work:

@inproceedings{zhou2021vda,
  author    = {Kun Zhou and Wayne Xin Zhao and Sirui Wang and Fuzheng Zhang and Wei Wu and Ji-Rong Wen},
  title     = {Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models},
  booktitle = {{EMNLP} 2021},
  publisher = {The Association for Computational Linguistics},
  year      = {2021},
}