This repository presents an implementation of the Wav2Vec2 model [1] in TensorFlow 2.0, developed as part of Google Summer of Code (GSoC).
For a quick demo, please check out this. The final report of the project can be found here.
Notebooks
The repository comes with shiny Colab Notebooks. Below you can find a list of them. Spin them up and don't forget to have fun!
Checkpoints
Below is a summary of checkpoints obtained during the project:
| Checkpoint | TFHub SavedModel | Description |
|---|---|---|
| gsoc-wav2vec2 | wav2vec2 | This checkpoint is TensorFlow's equivalent of the pre-trained Wav2Vec2 by Facebook. PyTorch weights are converted into TensorFlow using `convert_torch_to_tf.py`. |
| gsoc-wav2vec2-960h | wav2vec2-960h | This checkpoint is TensorFlow's equivalent of the fine-tuned Wav2Vec2 by Facebook. PyTorch weights are converted into TensorFlow using `convert_torch_to_tf.py`. |
| finetuned-wav2vec2-960h | - | This checkpoint is obtained by fine-tuning the Wav2Vec2 model on 960h of the LibriSpeech dataset during my GSoC tenure. You can reproduce training by running `main.py` on a TPU v3-8. |
To learn more about how the first two checkpoints were obtained, please check out the "Running Conversion script" section; for the last checkpoint, see the "Reproducing this project" section.
Using this Repository
The `Wav2Vec2` model from this repository can be installed using `pip`:
```shell
# this will install the wav2vec2 package
pip3 install git+https://github.com/vasudevgupta7/gsoc-wav2vec2
```
You can build the model from a config or load the fine-tuned checkpoint like this:
```python
from wav2vec2 import Wav2Vec2ForCTC, Wav2Vec2Config

config = Wav2Vec2Config()
model = Wav2Vec2ForCTC(config)
# now use this model like any other TF model

# in case you are interested in an already trained model, use the `.from_pretrained` method
model_id = "finetuned-wav2vec2-960h"
model = Wav2Vec2ForCTC.from_pretrained(model_id)
```
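`Wav2Vec2ForCTC` produces per-frame scores (logits) over a character vocabulary, so a decoding step is needed to get text. Below is a minimal, framework-free sketch of greedy (best-path) CTC decoding. It assumes that token id 0 is the CTC blank and that `|` marks word boundaries, a common convention in Wav2Vec2 character vocabularies; the toy vocabulary is hypothetical and this is not the decoding code shipped with this repository.

```python
from typing import List, Sequence


def greedy_ctc_decode(
    logits: Sequence[Sequence[float]],
    vocab: List[str],
    blank_id: int = 0,         # assumption: id 0 is the CTC blank
    word_delimiter: str = "|",  # assumption: "|" separates words
) -> str:
    """Best-path CTC decoding for one utterance of shape (frames, vocab_size)."""
    # 1. Pick the highest-scoring token id in every frame.
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in logits]

    # 2. Collapse consecutive repeats, then 3. drop CTC blanks.
    pieces = []
    previous = None
    for token_id in best_path:
        if token_id != previous and token_id != blank_id:
            pieces.append(vocab[token_id])
        previous = token_id

    # 4. Turn the word-delimiter token into spaces.
    return "".join(pieces).replace(word_delimiter, " ").strip()


# Toy example with a hypothetical 5-token vocabulary (id 0 = blank).
vocab = ["<pad>", "|", "A", "B", "C"]
frame_ids = [2, 2, 0, 2, 1, 3, 3, 0, 4]
logits = [[1.0 if i == t else 0.0 for i in range(len(vocab))] for t in frame_ids]
print(greedy_ctc_decode(logits, vocab))  # AA BC
```

The blank between the two `A` frames is what lets CTC emit a genuine double letter; without it the repeats would be collapsed into one. With a real model, assuming it returns a `(batch, frames, vocab_size)` tensor, you would pass something like `logits[0].numpy().tolist()` together with the model's actual vocabulary.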
Additionally, you can use the SavedModel from TFHub like this:
```python
import tensorflow_hub as hub

model_url = "https://tfhub.dev/vasudevgupta7/wav2vec2-960h/1"
model = hub.KerasLayer(model_url)
# use this `model`, just like any other TF SavedModel
```
Please check out the notebooks referred to in this repository for more information on how to use the `Wav2Vec2` model.
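The model consumes raw 16 kHz waveforms, and its convolutional feature encoder shortens the time axis before the Transformer layers. The sketch below estimates how many output frames (and hence CTC logit frames) to expect for a given number of input samples, using the kernel widths and strides reported in the paper [1]. It assumes unpadded ("valid") convolutions; the helper names are hypothetical and this is not code from this repository.

```python
from typing import Sequence

# Feature-encoder configuration reported in the Wav2Vec2 paper [1]:
# seven temporal convolutions with these kernel widths and strides.
KERNELS = (10, 3, 3, 3, 3, 2, 2)
STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples: int, kernels: Sequence[int] = KERNELS,
                      strides: Sequence[int] = STRIDES) -> int:
    """Frames emitted for `num_samples` inputs, assuming unpadded convolutions."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        if length < kernel:
            return 0  # too short to produce even one output frame
        length = (length - kernel) // stride + 1
    return length


def receptive_field(kernels: Sequence[int] = KERNELS,
                    strides: Sequence[int] = STRIDES) -> int:
    """Number of input samples that influence a single output frame."""
    field, hop = 1, 1
    for kernel, stride in zip(kernels, strides):
        field += (kernel - 1) * hop
        hop *= stride
    return field


print(num_output_frames(16_000))  # 49 frames for 1 s of 16 kHz audio
print(receptive_field())          # 400 samples, i.e. 25 ms at 16 kHz
```

These numbers agree with the paper's description of the encoder (about 49 Hz output, a hop of 320 samples or 20 ms, and a 25 ms receptive field); for example, a 10-second clip (160,000 samples) yields 499 frames.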
Reproducing this project
Setting Up
```shell
# install & setup TensorFlow first
pip3 install tensorflow

# install other requirements of this project using the following command:
pip3 install -qr requirements.txt
sudo apt-get install libsndfile1-dev

# switch to code directory for further steps
cd src
```
To use TPUs, the model weights and datasets must be stored in GCS buckets so that the TPU can access them directly. Hence, we will create two GCS buckets: one for checkpointing and the other for storing the LibriSpeech TFRecords.
```shell
# these bucket names will be required to run the training script later
export DATA_BUCKET_NAME="gsoc-librispeech-us"
export CKPT_BUCKET_NAME="gsoc-checkpoints-us"

# create GCS buckets
gsutil mb gs://${DATA_BUCKET_NAME}
gsutil mb gs://${CKPT_BUCKET_NAME}
```
Preparing dataset
Now we will download the LibriSpeech dataset from the official website and convert it into TFRecords using `make_tfrecords.py`. Finally, we will export all the TFRecords to the GCS bucket.
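`make_tfrecords.py` reads the extracted LibriSpeech folder. As a rough illustration of that input (not a reproduction of the script), the sketch below walks the standard LibriSpeech layout, where each `<speaker>/<chapter>/` directory holds FLAC files plus a `<speaker>-<chapter>.trans.txt` file with one `<utterance-id> <TRANSCRIPT>` line per recording, and pairs every audio path with its transcript. The function name is hypothetical.

```python
from pathlib import Path
from typing import Iterator, Tuple


def iter_librispeech(split_dir: str) -> Iterator[Tuple[Path, str]]:
    """Yield (flac_path, transcript) pairs from an extracted LibriSpeech split."""
    # Transcript files live two levels down: <speaker>/<chapter>/<speaker>-<chapter>.trans.txt
    for trans_file in sorted(Path(split_dir).glob("*/*/*.trans.txt")):
        for line in trans_file.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue  # skip blank lines defensively
            # Each line is "<utterance-id> <TRANSCRIPT>"; the id also names the FLAC file.
            utterance_id, transcript = line.split(" ", 1)
            yield trans_file.parent / f"{utterance_id}.flac", transcript
```

For example, `for path, text in iter_librispeech("LibriSpeech/dev-clean"): ...` iterates over every utterance of the `dev-clean` split; decoding the FLAC audio and serialising it into TFRecords is left to the actual script.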
```shell
# possible values are `dev-clean`, `train-clean-100`, `train-clean-360`, `train-other-500`, `test-clean`
# you will have to repeat the same steps for every split listed above
export DATA_SPLIT=dev-clean

wget https://www.openslr.org/resources/12/${DATA_SPLIT}.tar.gz
tar -xf ${DATA_SPLIT}.tar.gz

python3 make_tfrecords.py --data_dir LibriSpeech/${DATA_SPLIT} -d ${DATA_SPLIT} -n 50

# transfer tfrecords to GCS bucket
gsutil cp -r ${DATA_SPLIT} gs://${DATA_BUCKET_NAME}/${DATA_SPLIT}
```
Now your GCS bucket (`DATA_BUCKET_NAME`) should look like this:
```
.
|- ${DATA_SPLIT}
    |- ${DATA_SPLIT}-0.tfrecord
    |- ${DATA_SPLIT}-1.tfrecord
    .
    .
```
Follow the above steps for all other data splits. You just need to change the `DATA_SPLIT` environment variable.
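Assuming the bucket layout shown above, the shards of each split can be matched with one glob pattern per split. The helper below is a hypothetical sketch (not part of this repository) that builds those patterns; the resulting strings could, for example, be passed to `tf.io.gfile.glob` to list the shards before building a `tf.data.TFRecordDataset`.

```python
from typing import Dict, Sequence

# Split names from the data-preparation step; the three train-* splits
# together make up the 960 hours of LibriSpeech (100 + 360 + 500).
SPLITS = ("dev-clean", "train-clean-100", "train-clean-360", "train-other-500", "test-clean")


def tfrecord_patterns(bucket: str, splits: Sequence[str] = SPLITS) -> Dict[str, str]:
    """Map each split to a glob over its shards, following the bucket layout above."""
    return {split: f"gs://{bucket}/{split}/{split}-*.tfrecord" for split in splits}
```

For instance, `tfrecord_patterns(os.environ["DATA_BUCKET_NAME"])["dev-clean"]` gives `gs://<bucket>/dev-clean/dev-clean-*.tfrecord`.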
Model training
Now that everything is installed and the GCS buckets are configured, we just need to run one command to start training.
Note: The following commands assume that you have already exported the `DATA_BUCKET_NAME` and `CKPT_BUCKET_NAME` environment variables.
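Since the training commands rely on these two variables, a quick pre-flight check can save a failed launch. The snippet below is a hypothetical helper (not part of `main.py`) that fails early when either variable is missing or empty and otherwise returns the corresponding `gs://` URIs.

```python
import os
from typing import Mapping, Tuple

REQUIRED_VARS = ("DATA_BUCKET_NAME", "CKPT_BUCKET_NAME")


def bucket_uris(env: Mapping[str, str] = os.environ) -> Tuple[str, str]:
    """Return (data_uri, ckpt_uri), raising if a required variable is not exported."""
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("export these environment variables first: " + ", ".join(missing))
    return f"gs://{env['DATA_BUCKET_NAME']}", f"gs://{env['CKPT_BUCKET_NAME']}"
```

Calling `bucket_uris()` at the top of a launch script surfaces a missing `export` immediately, instead of partway through TPU initialisation.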
The following command will fine-tune the Wav2Vec2 model on single or multiple GPUs, or on Colab/Kaggle TPUs:

```shell
python3 main.py
```
For training on Cloud TPUs, run the following command:
```shell
# set the `TPU_NAME` environment variable to the name of your Cloud TPU (replace <tpu-name>)
# this ensures that your VM connects to the specified TPU & the TPU becomes visible to TensorFlow
TPU_NAME=<tpu-name> python3 main.py
```
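We have not reproduced `main.py` here, but a common way for a TensorFlow 2 script to honour a `TPU_NAME` variable is sketched below: resolve the TPU, connect to it and build a `tf.distribute.TPUStrategy`, otherwise fall back to `tf.distribute.MirroredStrategy` for GPUs. The helper names are hypothetical, and the auto-detection and error behaviour of `TPUClusterResolver` varies across TensorFlow versions and environments, so treat this as an assumption-laden sketch.

```python
import os
from typing import Mapping, Optional


def resolve_tpu_name(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the Cloud TPU name from `TPU_NAME`, or None when it is unset or empty."""
    name = env.get("TPU_NAME", "").strip()
    return name or None


def make_strategy(tpu_name: Optional[str] = None):
    """Pick a tf.distribute strategy: TPU if one is found, else GPUs/CPU."""
    import tensorflow as tf  # imported lazily so resolve_tpu_name() works without TF

    try:
        # tpu=None asks TensorFlow to auto-detect a TPU (e.g. on Colab or Kaggle).
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    except ValueError:
        if tpu_name is not None:
            raise  # an explicitly requested TPU should not be silently ignored
        resolver = None

    if resolver is not None:
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    # No TPU: replicate across all visible GPUs (falls back to CPU if there are none).
    return tf.distribute.MirroredStrategy()
```

Typical usage would be `strategy = make_strategy(resolve_tpu_name())`, followed by creating and compiling the model inside `with strategy.scope():`.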
Running Conversion script
Original PyTorch checkpoints (from Facebook) can be converted using the conversion script available in this repository.
```shell
# --hf_model_id: HuggingFace Hub ID of the model you want to convert
# --with_lm_head: whether to use `Wav2Vec2ForCTC` or `Wav2Vec2Model` from this repository
python3 convert_torch_to_tf.py \
  --hf_model_id facebook/wav2vec2-base \
  --with_lm_head
```
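Much of such a conversion comes down to copying tensors in the right memory layout, because PyTorch and Keras store some weights in different orders. As a hedged illustration (not necessarily how `convert_torch_to_tf.py` does it), the sketch below shows the two transpositions that typically matter for Wav2Vec2: `nn.Linear` weights of shape `(out, in)` become Keras `Dense` kernels of shape `(in, out)`, and `nn.Conv1d` weights of shape `(out, in, kernel)` become Keras `Conv1D` kernels of shape `(kernel, in, out)`. Plain nested lists keep it dependency-free; with NumPy these are `w.T` and `w.transpose(2, 1, 0)`.

```python
from typing import List

Matrix = List[List[float]]
Tensor3 = List[List[List[float]]]


def linear_torch_to_keras(weight: Matrix) -> Matrix:
    """nn.Linear weight (out_features, in_features) -> Dense kernel (in_features, out_features)."""
    # Keras computes x @ kernel, PyTorch computes x @ weight.T, so kernel = weight.T.
    return [list(column) for column in zip(*weight)]


def conv1d_torch_to_keras(weight: Tensor3) -> Tensor3:
    """nn.Conv1d weight (out, in, kernel) -> Conv1D kernel (kernel, in, out)."""
    out_ch, in_ch, width = len(weight), len(weight[0]), len(weight[0][0])
    # Both frameworks compute cross-correlation, so the kernel is reordered, not flipped.
    return [[[weight[o][i][t] for o in range(out_ch)] for i in range(in_ch)]
            for t in range(width)]
```

Biases can be copied unchanged. Layers with weight normalisation or grouped convolutions (such as Wav2Vec2's positional convolution) need extra handling that this sketch does not cover.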
Running tests
```shell
# first install `torch` & `transformers`
pip3 install torch transformers

# run this from the root of this repository
pytest -sv tests
```
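The tests need `torch` and `transformers`, presumably because the TensorFlow port is checked against the reference PyTorch implementation. The core of such a check is a tolerance-based comparison of flattened outputs; a minimal sketch follows, where the helper names and the `1e-4` tolerance are assumptions rather than what `tests/` actually uses.

```python
from typing import Sequence


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flattened outputs."""
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)} elements")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def assert_outputs_match(tf_out: Sequence[float], torch_out: Sequence[float],
                         atol: float = 1e-4) -> None:
    """Fail if the TensorFlow and PyTorch outputs differ by more than `atol`."""
    diff = max_abs_diff(tf_out, torch_out)
    assert diff <= atol, f"max |tf - torch| = {diff:.2e} exceeds atol = {atol:.0e}"
```

With real tensors you would flatten them first, e.g. `tf_out.numpy().ravel().tolist()` and `torch_out.detach().numpy().ravel().tolist()`; in practice `numpy.testing.assert_allclose` does the same job.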
Acknowledgement
- Sayak Paul, Morgan Roff, Jaeyoun Kim for mentoring me throughout the project.
- TensorFlow team & TRC for providing access to TPUs during my GSoC tenure.
References
[1] Baevski, Alexei, et al. “Wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.” ArXiv:2006.11477 [Cs, Eess], Oct. 2020. arXiv.org, http://arxiv.org/abs/2006.11477.
End Notes
Please create an issue if you encounter any problems while using this project. Don't forget to