Train 🤗 Transformers with DeepSpeed: ZeRO-2, ZeRO-3

Overview

Forked from https://github.com/huggingface/transformers/tree/86d5fb0b360e68de46d40265e7c707fe68c8015b/examples/pytorch/language-modeling on 2021-05-17.

Language model training

Fine-tune (or train from scratch) the library models for language modeling on a text dataset, for models such as GPT, GPT-2, ALBERT, BERT, DistilBERT, RoBERTa and XLNet. GPT and GPT-2 are trained or fine-tuned using a causal language modeling (CLM) loss, while ALBERT, BERT, DistilBERT and RoBERTa are trained or fine-tuned using a masked language modeling (MLM) loss. XLNet uses permutation language modeling (PLM). You can find more information about the differences between those objectives in our model summary.
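As a rough illustration of how the CLM and MLM objectives differ, the sketch below builds the prediction targets for each from a toy token list. The function names and the "[MASK]" placeholder are hypothetical and are not taken from the scripts.

```python
def clm_examples(tokens):
    """Causal LM: predict each token from the tokens to its left only."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def mlm_example(tokens, mask_positions, mask_token="[MASK]"):
    """Masked LM: hide the chosen positions and predict them from both sides."""
    corrupted = [mask_token if i in mask_positions else tok
                 for i, tok in enumerate(tokens)]
    targets = {i: tokens[i] for i in mask_positions}
    return corrupted, targets


tokens = ["the", "cat", "sat", "down"]
print(clm_examples(tokens))      # (left context, next token) pairs
print(mlm_example(tokens, {1}))  # corrupted input plus the hidden token
```

In the CLM case every position is a target and the context is one-directional; in the MLM case only the masked positions contribute to the loss, but the model sees context on both sides.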

There are two sets of scripts provided. The first set leverages the Trainer API. The second set, with no_trainer in the suffix, uses a custom training loop and leverages the 🤗 Accelerate library. Both sets use the 🤗 Datasets library. You can easily customize them if you need extra processing on your datasets.

Note: The old script run_language_modeling.py is still available here.

The following examples will run on datasets hosted on our hub or on your own text files for training and validation. We give examples of both below.

GPT-2/GPT and causal language modeling

The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before the tokenization). The loss here is that of causal language modeling.

python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm

This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run. It reaches a perplexity of about 20 once fine-tuned on the dataset.
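For reference, perplexity is the exponential of the mean per-token negative log-likelihood, so it can be derived from the evaluation loss. The helper below is a minimal sketch with a hypothetical function name and made-up log-probabilities, not output from the script.

```python
import math


def perplexity(token_log_probs):
    """Perplexity = exp(mean negative log-likelihood per token)."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# A model that assigns probability 1/20 to every target token
# has a perplexity of (approximately, in floating point) 20.
print(perplexity([math.log(1 / 20)] * 5))
```

Read this way, a perplexity of about 20 means the model is, on average, as uncertain as if it were choosing uniformly among roughly 20 tokens at each step.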

To run on your own training and validation files, use the following command:

python run_clm.py \
    --model_name_or_path gpt2 \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm

This uses the built-in HuggingFace Trainer for training. If you want to use a custom training loop, you can use or adapt the run_clm_no_trainer.py script. Take a look at the script for a list of supported arguments. An example is shown below:

python run_clm_no_trainer.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --model_name_or_path gpt2 \
    --output_dir /tmp/test-clm

RoBERTa/BERT/DistilBERT and masked language modeling

The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their pre-training: masked language modeling.

In accordance with the RoBERTa paper, we use dynamic masking rather than static masking. The model may therefore converge slightly more slowly (over-fitting takes more epochs).

python run_mlm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm
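The dynamic masking mentioned above can be sketched as follows: instead of fixing the masked positions once during preprocessing, a new random mask is drawn every time a sequence is batched. This is a simplified pure-Python sketch that assumes the usual BERT-style recipe (15% of tokens selected; of those, 80% replaced by the mask token, 10% by a random token, 10% left unchanged). The function name, token ids and vocabulary size are hypothetical.

```python
import random


def dynamic_mask(token_ids, vocab_size, mask_id, mlm_probability=0.15, rng=random):
    """Return (inputs, labels) with a fresh random mask on every call."""
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)  # -100 marks positions ignored by the loss
    for i, tok in enumerate(token_ids):
        if rng.random() >= mlm_probability:
            continue  # position not selected for prediction
        labels[i] = tok
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id                    # replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # replace with a random token
        # otherwise keep the original token as input
    return inputs, labels


ids = list(range(10, 30))  # hypothetical token ids
for epoch in range(2):
    # Each call (for example, each epoch) draws a different mask
    print(epoch, dynamic_mask(ids, vocab_size=100, mask_id=4))
```

Because the model sees different masked positions for the same sentence across epochs, it takes longer to memorize the training set, which is consistent with the slower over-fitting noted above.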

To run on your own training and validation files, use the following command:

python run_mlm.py \
    --model_name_or_path roberta-base \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm

If your dataset is organized with one sample per line, you can use the --line_by_line flag (otherwise the script concatenates all texts and then splits them in blocks of the same length).
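The two preprocessing modes can be sketched like this. The sketch works on already tokenized lines and assumes that the leftover tokens after the last full block are dropped; the function names here are illustrative rather than copied from the script.

```python
def group_texts(tokenized_lines, block_size):
    """Default mode: concatenate all tokens, then cut into equal-sized blocks."""
    flat = [tok for line in tokenized_lines for tok in line]
    usable = (len(flat) // block_size) * block_size  # drop the small remainder
    return [flat[i:i + block_size] for i in range(0, usable, block_size)]


def line_by_line(tokenized_lines, max_length):
    """--line_by_line mode: one sample per non-empty line, truncated if too long."""
    return [line[:max_length] for line in tokenized_lines if line]


lines = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(group_texts(lines, 4))   # blocks may span line boundaries
print(line_by_line(lines, 4))  # samples keep line boundaries, lengths vary
```

Grouping wastes no tokens on padding but mixes unrelated lines into one sample, while the line-by-line mode preserves sample boundaries at the cost of variable lengths.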

This uses the built-in HuggingFace Trainer for training. If you want to use a custom training loop, you can use or adapt the run_mlm_no_trainer.py script. Take a look at the script for a list of supported arguments. An example is shown below:

python run_mlm_no_trainer.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --model_name_or_path roberta-base \
    --output_dir /tmp/test-mlm

Note: On TPU, you should use the flag --pad_to_max_length in conjunction with the --line_by_line flag to make sure all your batches have the same length.
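The reason is that TPU programs are compiled for fixed tensor shapes, so batches of varying length can trigger repeated recompilation. The sketch below shows the idea of padding every sample to one fixed length; the helper name and the pad id of 0 are assumptions made for illustration.

```python
def pad_to_fixed_length(batch, max_length, pad_id=0):
    """Truncate or right-pad every sequence so all have exactly max_length tokens."""
    padded = []
    for seq in batch:
        seq = seq[:max_length]
        padded.append(seq + [pad_id] * (max_length - len(seq)))
    return padded


print(pad_to_fixed_length([[1, 2, 3], [4], [5, 6, 7, 8, 9]], max_length=4))
```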

Whole word masking

This part was moved to examples/research_projects/mlm_wwm.

XLNet and permutation language modeling

XLNet uses a different training objective, which is permutation language modeling. It is an autoregressive method to learn bidirectional contexts by maximizing the expected likelihood over all permutations of the input sequence factorization order.

We use the --plm_probability flag to define the ratio of length of a span of masked tokens to surrounding context length for permutation language modeling.

The --max_span_length flag may also be used to limit the length of a span of masked tokens used for permutation language modeling.
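To make the interaction of these two flags concrete, here is a simplified pure-Python sketch of span sampling: a span length is drawn up to max_span_length, the surrounding context window is sized as span length divided by plm_probability, and the span is placed at a random offset inside that window. This is an approximation written for illustration, not the library's data collator; the function name and the default values are assumptions.

```python
import random


def sample_masked_spans(seq_len, plm_probability=1 / 6, max_span_length=5, rng=random):
    """Return a list of booleans marking which positions are masked (0 < plm_probability <= 1)."""
    masked = [False] * seq_len
    cur = 0
    while cur < seq_len:
        span = rng.randint(1, max_span_length)  # span length, inclusive bounds
        context = int(span / plm_probability)   # window that contains the span
        start = cur + rng.randrange(context - span + 1)
        for i in range(start, min(start + span, seq_len)):
            masked[i] = True
        cur += context                          # move on to the next window
    return masked


mask = sample_masked_spans(seq_len=24, plm_probability=0.25, max_span_length=3)
print("".join("M" if m else "." for m in mask))
```

With plm_probability set to 0.25, about one token in four ends up masked, and no single masked span is longer than max_span_length.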

Here is how to fine-tune XLNet on wikitext-2:

python run_plm.py \
    --model_name_or_path=xlnet-base-cased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-plm

To fine-tune it on your own training and validation files, run:

python run_plm.py \
    --model_name_or_path=xlnet-base-cased \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-plm

If your dataset is organized with one sample per line, you can use the --line_by_line flag (otherwise the script concatenates all texts and then splits them in blocks of the same length).

Note: On TPU, you should use the flag --pad_to_max_length in conjunction with the --line_by_line flag to make sure all your batches have the same length.

Owner
Junbum Lee