Tutorial to pretrain & fine-tune a 🤗 Flax T5 model on a TPUv3-8 with GCP

Overview

Pretrain and Fine-tune a T5 model with Flax on GCP

This tutorial details how to pretrain and fine-tune a FlaxT5 model from HuggingFace using a TPU VM available on Google Cloud.

While the code is only slightly adapted from the original HuggingFace examples for pretraining and seq2seq fine-tuning, this repository aims to provide a comprehensive overview of the whole process, with a special focus on pitfalls caused by an incorrect environment setup.

Why JAX/Flax? Thanks to the amazing work of the HuggingFace team, a good portion of the models available in the transformers library are available in Flax, a neural network library built on top of JAX. Using Flax for pretraining on TPUs is especially convenient for two reasons:

  • TPUs are optimized for training deep models, and Flax is optimized for TPUs. A benchmark conducted by Huggingface showed that BERT pretraining with Flax takes only 65% of the time needed with PyTorch/XLA to achieve comparable performance on a TPUv3-8 VM, and only 35% of the time needed by PyTorch training on 8 V100 GPUs. Since language models are usually pretrained for days on massive amounts of text, the speedup is especially desirable here.

  • Out-of-the-box compatibility with PyTorch and TensorFlow in the transformers framework. A Flax model can be easily converted to PyTorch, for example, by using T5ForConditionalGeneration.from_pretrained("path/to/flax/ckpt", from_flax=True).

The code and instructions contained in this repository were used to pretrain the models gsarti/t5-base-it and gsarti/t5-large-it available on the Huggingface Hub, using ~270GB of cleaned web-scraped Italian texts. The dataset is also available on the Huggingface Hub under the name gsarti/clean_mc4_it.

Setup on your machine

Follow the instructions detailed in the Cloud SDK Install Guide to install the gcloud client on your system.

TPU VMs only have ~100GB of disk space, which is highly unlikely to be enough to store your raw + preprocessed dataset and all model checkpoints. GCP offers a choice among different disk types and limits the total disk space a user can create. At the time of writing, for a non-free-trial user in Europe without special access, the limit is 250GB for SSD (pd-balanced) and 2TB overall, including SSD and HDD (pd-standard).

### Define your variables
export GCP_PROJECT="<YOUR_PROJECT_NAME>"
export GCP_ZONE="<YOUR_ZONE>"
export GCP_TPU_NAME="<YOUR_TPU_NAME>"

# >>>>> Run this part to add a disk to your TPU VM
export GCP_DISK_NAME="<YOUR_DISK_NAME>"
export GCP_DISK_SIZE_GB=1200
export GCP_DISK_TYPE=pd-standard


gcloud beta compute disks create $GCP_DISK_NAME \
    --project=$GCP_PROJECT \
    --type=$GCP_DISK_TYPE \
    --size="${GCP_DISK_SIZE_GB}GB" \
    --zone=$GCP_ZONE
# <<<<<

# Create the TPU VM
# If a disk is used, uncomment the --data-disk line and add a trailing backslash after "--version v2-alpha"
gcloud alpha compute tpus tpu-vm create $GCP_TPU_NAME \
    --zone $GCP_ZONE \
    --project $GCP_PROJECT \
    --accelerator-type v3-8 \
    --version v2-alpha
    #--data-disk source="projects/${GCP_PROJECT}/zones/${GCP_ZONE}/disks/${GCP_DISK_NAME}"

Inside the TPUv3-8 machine

Log into the machine with the following command:

gcloud alpha compute tpus tpu-vm ssh $GCP_TPU_NAME --zone $GCP_ZONE --project $GCP_PROJECT

The system is a classic Ubuntu 20.04 with some utility libraries already available (e.g. tmux, vim). If you created a disk, the following commands show how to set it up inside the machine. The disk in the example is available on /dev/sdb, and we mount it on the data folder in the home directory:

# Check that your disk is visible and get its name
lsblk
# Mount disk in the data folder
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
sudo mkdir -p data
sudo mount -o discard,defaults /dev/sdb data
sudo chmod a+w data
# Optionally, add the automatic mounting: https://cloud.google.com/compute/docs/disks/add-persistent-disk#configuring_automatic_mounting_on_vm_restart

# Permanently export the disk as the new path for HF dataset caching
# (an absolute path is used so the cache is found from any working directory)
echo "export HF_DATASETS_CACHE=$HOME/data" >> ~/.bashrc
source ~/.bashrc

We then set up the environment with the required libraries and dependencies:

### Setup the TPUv3-8 environment
sudo apt update
sudo apt-get install python3.8-venv
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -r requirements.txt
sudo apt-get install git-lfs
git lfs install

The next step is to create a repository on the Huggingface Hub (here e.g. t5-base-it) that will contain all our files. We then train a tokenizer, select a config, and push the two JSON files to the repository we created.

Depending on the number of sentences that need to be processed, training the tokenizer may take quite some time (~2 hours for 10M sentences).

# Huggingface Hub variables
export HF_NAME="<YOUR_HF_HUB_USERNAME>"
export HF_PWD="<YOUR_HF_HUB_PASSWORD>"
export HF_TOKEN="<YOUR_HF_HUB_API_TOKEN>"
export HF_PROJECT="t5-base-it"

# Variables for training the tokenizer and creating the config
export VOCAB_SIZE="32000"
export N_INPUT_SENTENCES="1000000" # Num of sentences to train the tokenizer
export DATASET="gsarti/clean_mc4_it" # Name of the dataset in the Huggingface Hub
export DATASET_CONFIG="full" # Config of the dataset in the Huggingface Hub 
export DATASET_SPLIT="train" # Split to use for training tokenizer and model
export TEXT_FIELD="text" # Field containing the text to be used for training
export CONFIG_TYPE="google/t5-v1_1-base" # Config that our model will use
export MODEL_PATH="data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount

### Setup the project on the HF Hub
# huggingface-cli login also works
# Make sure the token folder exists before writing to it
mkdir -p ~/.huggingface
echo $HF_TOKEN > ~/.huggingface/token
huggingface-cli repo create $HF_PROJECT
git clone "https://huggingface.co/${HF_NAME}/${HF_PROJECT}"
mv $HF_PROJECT $MODEL_PATH

# Create the tokenizer and the config
python create_tokenizer_cfg.py \
    --model_dir $MODEL_PATH \
    --dataset $DATASET \
    --dataset_config $DATASET_CONFIG \
    --dataset_split $DATASET_SPLIT \
    --text_field $TEXT_FIELD \
    --vocab_size $VOCAB_SIZE \
    --input_sentence_size $N_INPUT_SENTENCES \
    --config_type $CONFIG_TYPE

# Push the tokenizer and the config to the HF Hub
cd $MODEL_PATH
git lfs track "*tfevents*"
git add .
git commit -m "Added tokenizer and config"
# If this doesn't work, do a normal push and insert your credentials
git push "https://${HF_NAME}:${HF_PWD}@huggingface.co/${HF_NAME}/${HF_PROJECT}"
# Go back to the directory we started from (MODEL_PATH is two levels deep)
cd -

The last step is to run the pretraining. The following command will train the model on the same dataset used to train the tokenizer, logging train/eval metrics to Tensorboard and pushing model checkpoints and artifacts to the Hub every save_steps steps:

Important: Since checkpoints are stored with Git LFS on the Hub, depending on the value of save_steps you may run into disk space problems (e.g. training a model with 1GB checkpoints for 1M steps with save_steps=10_000 will need 100GB of free space to avoid crashes during training). Two ways to avoid problems are:

  • Use save_steps > (tot_steps / your_available_disk_space_gb) * ckpt_size_gb. E.g. if we are doing 1M steps, the checkpoint size is 1GB and we only have 76GB free, we will need save_steps higher than 1M / 76 ~= 13158. Failing to do so will interrupt the training once the disk is full.

  • Use git lfs prune --verify-remote to periodically remove old cached checkpoints from .git/lfs/objects and free up some space. If done often enough this will prevent any problem of space, but requires manual intervention of the user.
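The arithmetic behind the first option can be sketched in a few lines of Python. The helper names below are hypothetical and not part of the repository; the sketch assumes every saved checkpoint stays in the local Git LFS cache until training ends.

```python
import math


def checkpoints_disk_gb(total_steps, save_steps, ckpt_size_gb):
    """Disk space needed if every checkpoint saved during training is kept locally."""
    return (total_steps // save_steps) * ckpt_size_gb


def min_save_steps(total_steps, ckpt_size_gb, free_disk_gb):
    """Smallest integer save_steps strictly greater than
    (total_steps / free_disk_gb) * ckpt_size_gb."""
    return math.floor(total_steps * ckpt_size_gb / free_disk_gb) + 1


# Examples taken from the text above
print(checkpoints_disk_gb(1_000_000, 10_000, 1.0))  # 100.0
print(min_save_steps(1_000_000, 1.0, 76))  # 13158
```

With save_steps=13158 the run saves 75 checkpoints, which fits in the 76GB of the example.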

# Run the training
# Adjust the parameters to your needs
# To avoid using an extra disk for storing the model, 
python run_t5_mlm_flax.py \
    --output_dir=$MODEL_PATH \
    --model_type="t5" \
    --config_name=$MODEL_PATH \
    --tokenizer_name=$MODEL_PATH \
    --preprocessing_num_workers="96" \
    --do_train --do_eval \
    --adafactor \
    --dataset_name="gsarti/clean_mc4_it" \
    --dataset_config_name="full" \
    --max_seq_length="512" \
    --per_device_train_batch_size="8" \
    --per_device_eval_batch_size="8" \
    --learning_rate="0.005" \
    --overwrite_output_dir \
    --num_train_epochs="1" \
    --logging_steps="500" \
    --save_steps="80000" \
    --eval_steps="2500" \
    --weight_decay="0.01" \
    --warmup_steps="10000" \
    --validation_split_count="15000" \
    --push_to_hub \
    #--gradient_accumulation_steps="2" # Uncomment to use gradient accumulation
    #--resume_from_checkpoint=$MODEL_PATH # Uncomment to resume from ckpt

Exporting Checkpoints

Your model is now trained, congratulations! 🎉 Now you may want to export your final Flax model to other frameworks. This can be done easily:

python export_checkpoint.py --model_dir $MODEL_PATH
cd $MODEL_PATH
git add .
git commit -m "Added TF and PT models"
git push
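The contents of export_checkpoint.py are not shown here. As a rough idea of what the PyTorch half of such an export could look like, the sketch below relies on the from_flax=True loading mentioned in the overview. The function name is hypothetical, it assumes transformers, torch and flax are installed, and it does not cover the TensorFlow export.

```python
def export_flax_to_pytorch(model_dir):
    """Load the Flax T5 weights stored in model_dir and save PyTorch weights next to them."""
    # Imported lazily: transformers (with torch and flax) is only needed when the function runs.
    from transformers import T5ForConditionalGeneration

    # from_flax=True tells transformers to read the Flax checkpoint and convert it on load
    model = T5ForConditionalGeneration.from_pretrained(model_dir, from_flax=True)
    # Writes the PyTorch weights and the config into the same folder
    model.save_pretrained(model_dir)
    return model_dir


# Hypothetical usage, with the same folder used as MODEL_PATH above:
# export_flax_to_pytorch("data/t5-base-it")
```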

Fine-tune a pretrained T5 model in Flax

TODO

Useful Tips

About GCP Billing: Besides TPU costs, which can be offset by taking part in the TRC program, the second-highest billing item is the upload of files from the VM to another destination. This happens every save_steps steps during training, to keep the Tensorboard logs and the checkpoints on the HF Hub in sync. For in-region uploads, at the time of writing the cost is up to 0.12 USD per GB. This means that for a full pretraining of 1M steps with a checkpoint saved every 10k steps, and checkpoints of roughly 1GB each, the cost is roughly 12 USD. To reduce this cost, simply increase the value of save_steps passed to the run_t5_mlm_flax.py script.
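The billing estimate above is a simple product, sketched here so it can be recomputed for other settings. The function name is hypothetical, the 0.12 USD per GB price is the figure quoted above (check current GCP pricing), and Tensorboard logs are ignored since the checkpoints dominate.

```python
def upload_cost_usd(total_steps, save_steps, ckpt_size_gb, usd_per_gb=0.12):
    """Approximate cost of uploading one checkpoint every save_steps steps."""
    n_uploads = total_steps // save_steps
    return n_uploads * ckpt_size_gb * usd_per_gb


# Example from the text: 1M steps, a 1GB checkpoint every 10k steps
print(round(upload_cost_usd(1_000_000, 10_000, 1.0), 2))  # 12.0
# Same run with the save_steps=80000 used in the training command: 12 uploads
print(round(upload_cost_usd(1_000_000, 80_000, 1.0), 2))  # 1.44
```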

About data preprocessing: Imagine you've trained a base model and now want to train a large variant using the same tokenizer. Even if the tokenizer is the same, if its path changes the cache will be invalidated and everything will be recomputed. This is not a tragedy: it only costs some extra preprocessing time. But mind that tokenization and grouping take substantial space on disk, so if there is not enough space left you will need to delete the Arrow files from the previous HuggingFace dataset cache. You can locate them with something like:

find . -type f -size +2G -size -4G -exec ls -lah {} + | grep 'Aug 25 13' | wc -l

As an example, knowing that my tokenization completed on Aug 25 at 13:xx and that the files in the cache are about 3GB each, this should return the value you set for n_proc. You can get the filenames of the old preprocessing caches by removing the wc -l pipe, then delete them all to make room for the new ones.
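The same selection can be sketched in Python, which makes it easier to review the matches before deleting anything. The function name and parameters are hypothetical. It roughly mirrors the find command by keeping files whose size falls between two bounds and whose modification time falls inside a time window.

```python
from datetime import datetime, timedelta
from pathlib import Path


def find_cache_files(root, min_bytes, max_bytes, start, window=timedelta(hours=1)):
    """Return files under root with min_bytes < size < max_bytes
    that were last modified in [start, start + window)."""
    end = start + window
    matches = []
    for path in Path(root).rglob("*"):
        if not path.is_file():
            continue
        info = path.stat()
        modified = datetime.fromtimestamp(info.st_mtime)
        if min_bytes < info.st_size < max_bytes and start <= modified < end:
            matches.append(path)
    return sorted(matches)


# Hypothetical usage for the example above (the year is an assumption):
# GIB = 1024 ** 3
# old = find_cache_files(".", 2 * GIB, 4 * GIB, datetime(2021, 8, 25, 13))
# print(len(old))          # should match the value set for n_proc
# for path in old:
#     path.unlink()        # delete only after checking the list
```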

Owner
Gabriele Sarti
PhD Student in NLP @ University of Groningen | Prev: Aindo, ILC-CNR Pisa, Skytech Montreal