ZeroGen: Efficient Zero-shot Learning via Dataset Generation

Overview

ZEROGEN

This repository contains the code for our paper “ZeroGen: Efficient Zero-shot Learning via Dataset Generation”. Our implementation is built on the source code from dino. Thanks for their work.

If you use this code, please cite our paper:

@article{ye2022zerogen,
      title={ZeroGen: Efficient Zero-shot Learning via Dataset Generation}, 
      author={Jiacheng Ye and Jiahui Gao and Qintong Li and Hang Xu and Jiangtao Feng and Zhiyong Wu and Tao Yu and Lingpeng Kong},
      year={2022},
      eprint={2202.07922},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Setup

All requirements for ZEROGEN can be found in requirements.txt. You can install all required packages in a new environment with pip install -r requirements.txt.

Usage

The scripts/run_cls.sh and scripts/run_qa.sh scripts contain the running commands for the following settings:

  • supervised learning with human annotations (SUPERVISED)
  • prompt-based zero-shot learning (PROMPTING)
  • efficient zero-shot learning via dataset generation (ZEROGEN)

For text classification (TC) tasks (e.g., SST-2 and IMDb) and natural language inference (NLI) tasks (e.g., QNLI and RTE), run with bash scripts/run_cls.sh. For question answering (QA) tasks, run with bash scripts/run_qa.sh

When generating X (i.e., denotes text in TC, hypothesis in NLI and question in QA) in the final stage of the scripts, we also train the small model and evaluate it on human annotations. Specifically, after generating log_every number of examples, we perform training on the synthetic dataset and evaluation on the gold validation set. This gives as a trend graph similar to Figure 2 in the paper, which is shown by wandb, a powerful toolkit to track experiments.

Before running, you need to reset the following parameters to yours:

  • home_dir: path to ZeroGen
  • gpu: gpu id
  • batch_size: the batch size for generating with PLM. For SST-2, it costs ~16G when using a batch size of 32 with gpt2-xl. While for SQuAD, it costs ~60G using the same batch size and PLM because of the longer contexts. So decrease the batch size if needed.
  • WANDB_PROJECT: project name, by default ZeroGen
  • WANDB_ENTITY: your wandb username
  • WANDB_API_KEY: your api-key

By default we use GPT2-XL as pre-trained language model (PLM) and DistilBERT as tiny-task model (TAM), to modify the size of PLM and TAM, you can change model_name and small_model_name in run_xxx.sh scripts.

Run with a synthesized dataset

After dataset generation, we save the synthetic dataset at:

  • For TC and NLI: out-${task_name}-x2/${dataset}/${task_name}-dataset.jsonl (e.g., out-sst-2-x2/gpt2-xl_topk0_topp0.9_sst-2-x2/sst-2-dataset.jsonl). The file is in json line format (e.g., {"C": "The Book of Mormon Musical", "X": "The Book of Mormon Musical brings all the drama and excitement of a real revival of the Broadway production to the big screen.", "Y": 0}).
  • For QA: out-${task_name}-x2/${dataset}. We save the dataset in huggingface Dataset format.

To run DistilBERT given a generated dataset, you can use the scripts/run_distilbert.sh script.

To run a LSTM-based model given a generated dataset, you can use the scripts/run_cls_lstm.sh script. Before that, you have to download the datasets from google drive link, which contain the standard test files.

Diversity and Correctness of a synthesized dataset

Divesity

We use Self-BLEU to measure the diversity of a synthesized dataset. To calculate the Self-BLEU for a given dataset, you can see the example in scripts/run_self_bleu.sh script.

Correctness

To calculate the Correctness, you can take the following steps:

  1. Replace the following parameters in scripts/run_distilbert.sh script with:

    • small_model_name=roberta-large
    • dataset=: empty means using standard training set
    • limit=: empty means using full standard training set

    This will give you a RoBERTa-Large trained with full human annotations, which can be used as an evaluator.

  2. Replace the following parameters in scripts/run_distilbert.sh script with:

    • small_model_ckpt=tmp/checkpoint-xxx: the final RoBERTa-Large checkpoint saved in step 1.
    • limit=10000: the number of samples to use, by default 10000
    • dataset=xxx: the name of synthetic dataset (e.g., gpt2-xl_topk0_topp0.9_sst-2-x2)
    • no_train=true: disable training

    Run the script, and you will get Metric on standard dataset and Metric on synthetic dataset, which represents the Correctness of standard dataset and synthetic dataset, respectively.

Resources

We provide some synthetic datasets and standard datasets for training LSTM in this google drive link. When training DistilBERT, the standard dataset is directly downloaded by huggingface Dataset package. Note we use the same prompt for IMDb/SST-2, and SQuAD/AdversarialQA, therefore the synthetic datasets are also the same.

Fast EMD for Python: a wrapper for Pele and Werman's C++ implementation of the Earth Mover's Distance metric

PyEMD: Fast EMD for Python PyEMD is a Python wrapper for Ofir Pele and Michael Werman's implementation of the Earth Mover's Distance that allows it to

William Mayner 433 Dec 31, 2022
An image classification app boilerplate to serve your deep learning models asap!

Image 🖼 Classification App Boilerplate Have you been puzzled by tons of videos, blogs and other resources on the internet and don't know where and ho

Smaranjit Ghose 27 Oct 06, 2022
Unsupervised Learning of Video Representations using LSTMs

Unsupervised Learning of Video Representations using LSTMs Code for paper Unsupervised Learning of Video Representations using LSTMs by Nitish Srivast

Elman Mansimov 341 Dec 20, 2022
Deep Learning Datasets Maker is a QGIS plugin to make datasets creation easier for raster and vector data.

Deep Learning Dataset Maker Deep Learning Datasets Maker is a QGIS plugin to make datasets creation easier for raster and vector data. How to use Down

deepbands 25 Dec 15, 2022
sense-py-AnishaBaishya created by GitHub Classroom

Compute Statistics Here we compute statistics for a bunch of numbers. This project uses the unittest framework to test functionality. Pass the tests T

1 Oct 21, 2021
Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World

Legged Robots that Keep on Learning Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World, whic

Laura Smith 70 Dec 07, 2022
CM-NAS: Cross-Modality Neural Architecture Search for Visible-Infrared Person Re-Identification (ICCV2021)

CM-NAS Official Pytorch code of paper CM-NAS: Cross-Modality Neural Architecture Search for Visible-Infrared Person Re-Identification in ICCV2021. Vis

JDAI-CV 40 Nov 25, 2022
利用python脚本实现微信、支付宝账单的合并,并保存到excel文件实现自动记账,可查看可视化图表。

KeepAccounts_v2.0 KeepAccounts.exe和其配套表格能够实现微信、支付宝官方导出账单的读取合并,为每笔帐标记类型,并按月份和类型生成可视化图表。再也不用消费一笔记一笔,每月仅需10分钟,记好所有的帐。 作者: MickLife Bilibili: https://spac

159 Jan 01, 2023
The codes and models in 'Gaze Estimation using Transformer'.

GazeTR We provide the code of GazeTR-Hybrid in "Gaze Estimation using Transformer". We recommend you to use data processing codes provided in GazeHub.

65 Dec 27, 2022
PRIN/SPRIN: On Extracting Point-wise Rotation Invariant Features

PRIN/SPRIN: On Extracting Point-wise Rotation Invariant Features Overview This repository is the Pytorch implementation of PRIN/SPRIN: On Extracting P

Yang You 17 Mar 02, 2022
codes for "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation" (long paper of EMNLP-2022)

Scheduled Sampling Based on Decoding Steps for Neural Machine Translation (EMNLP-2021 main conference) Contents Overview Background Quick to Use Furth

Adaxry 13 Jul 25, 2022
Stochastic Normalizing Flows

Stochastic Normalizing Flows We introduce stochasticity in Boltzmann-generating flows. Normalizing flows are exact-probability generative models that

AI4Science group, FU Berlin (Frank Noé and co-workers) 50 Dec 16, 2022
Applicator Kit for Modo allow you to apply Apple ARKit Face Tracking data from your iPhone or iPad to your characters in Modo.

Applicator Kit for Modo Applicator Kit for Modo allow you to apply Apple ARKit Face Tracking data from your iPhone or iPad with a TrueDepth camera to

Andrew Buttigieg 3 Aug 24, 2021
Customer-Transaction-Analysis - This analysis is based on a synthesised transaction dataset containing 3 months worth of transactions for 100 hypothetical customers.

Customer-Transaction-Analysis - This analysis is based on a synthesised transaction dataset containing 3 months worth of transactions for 100 hypothetical customers. It contains purchases, recurring

Ayodeji Yekeen 1 Jan 01, 2022
A Python library for common tasks on 3D point clouds

Point Cloud Utils (pcu) - A Python library for common tasks on 3D point clouds Point Cloud Utils (pcu) is a utility library providing the following fu

Francis Williams 622 Dec 27, 2022
Multi-tool reverse engineering collaboration solution.

CollaRE v0.3 Intorduction CollareRE is a tool for collaborative reverse engineering that aims to allow teams that do need to use more then one tool du

105 Nov 27, 2022
Text Generation by Learning from Demonstrations

Text Generation by Learning from Demonstrations The README was last updated on March 7, 2021. The repo is based on fairseq (v0.9.?). Paper arXiv Prere

38 Oct 21, 2022
3D-printable hand-strapped keyboard

Note: This repo has not been cleaned up and prepared for general consumption at all. This is just a dump of the project files. If there is any interes

Wojciech Baranowski 41 Dec 31, 2022
For IBM Quantum Challenge Africa 2021, 9 September (07:00 UTC) - 20 September (23:00 UTC).

IBM Quantum Challenge Africa 2021 To ensure Africa is able to apply quantum computing to solve problems relevant to the continent, the IBM Research La

Qiskit Community 48 Dec 25, 2022
Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Octavio Arriaga 5.3k Dec 30, 2022