PyTorch implementation of "data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language" from Meta AI

Overview

data2vec-pytorch

PyTorch implementation of "data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language" from Meta AI (FAIR)

Data2Vec is the first high-performance self-supervised algorithm that learns the same way in multiple modalities, including speech, vision and text. Most machines learn exclusively from labeled data. However, through self-supervised learning, machines are able to learn about the world just by observing it and then figuring out the structure of images, speech or text. This is a more scalable and efficient approach for machines to tackle new complex tasks, such as understanding text for more spoken languages.

In summary, the method is as follows:

  1. The encoder extracts features from the masked inputs. These features are outputs of every transformer/linear layer.
  2. The teacher which is an EMA instance of the encoder (in eval model), extracts features from the unmasked inputs.
  3. Optional normalizations are applied to the layers/outputs of the teacher.
  4. Encoder outputs are regressed by a projection block/layer.
  5. The loss is calculated from encoder outputs and teacher outputs.

You can read the paper for more detail.

Implementation

Data2Vec is already implemented in fairseq in which for all modalities there is a seperate implementation (text, vision, audio). According to the paper:

Our primary is to design a single learning mechanism for different modalities. Despite the unified learning regime, we still use modality-specific features extractors and masking strategies. This makes sense given the vastly different nature of the input data.

This implementation differs in the fact that a single Data2Vec model is provided powered by a custom encoder (implemented using PyTorch + HuggingFace Transformers) and tries to unify the whole concept in a single module. The key concept is that there must be modality-specific feature extractions and masking strategies.

  • Masking: For each modality, the Dataset instance must return the masked source, the target and the mask tensor.

  • Feature Extraction: Features are the outputs from the transformer/attention layers. So the forward method must return outputs from all Encoder blocks of the transformer model. HuggingFace Transformers/Fairseq models return transformer layers outputs separately out of the box.

This implementation uses HuggingFace Transformers models as encoders for Data2Vec which you can inspect in the encoder.py files for each modality. Although, you can provide your own encoder model. Just make sure that your encoder must be Transformer-based according to the paper and outputs from every encoder layer must be provided.

Note: This implementation's goal is to provide the necessary building blocks of Data2Vec so anyone can adapt it to their own use case with ease, so in order to make it easy to get hands on, some functionalities like mixed precision, distributed training, etc are not included to keep it as clean & simple as possible. If you only need to train a standard large scale Data2Vec model use the official repo.

Train

First things first, install the requirements:

pip install -r requirements.txt

NLP

Train a Language Model based on RoBERTa (HuggingFace) on WikiText103

Configure the related properties in text/configs/roberta-pretraining.yaml and run:

python train.py --config text/configs/roberta-pretraining.yaml 

Vision

Run a Masked Image modeling training based on BEiT (HuggingFace)

Pass the path to the image dataset in the config file at vision/configs/beit-pretraining.yaml under dataset > path > train/test and modify other properties as you desire and run the following:

python train.py --config vision/configs/beit-pretraining.yaml 

Speech

Audio pretraining based on Wav2Vec2 (HuggingFace) on timit dataset. If you want to use other datasets like librispeech provide it in audio/dataset.py (some minor changes to the timit class would do the job because both are loaded from HuggingFace datasets)

Configure other properties as you desire and run the following:

python train.py --config audio/configs/wav2vec2-pretraining.yaml 

Pre-trained Weights

The models are available on HuggingFace Hub and you can use them like below:

RoBERTa

Data2Vec model trained with RoBERTa as the encoder (data2vec-roberta-base)

from transformers import AutoModel, AutoConfig
from transformers import RobertaModel

checkpoint = 'arxyzan/data2vec-roberta-base'

# Option 1: load using AutoModel
data2vec_roberta = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by RobertaModel
data2vec_roberta = RobertaModel.from_pretrained(checkpoint)

BEiT

Data2Vec model trained with BEiT as the encoder (data2vec-beit-base)

from transformers import AutoModel, AutoConfig
from transformers import BeitModel

checkpoint = 'arxyzan/data2vec-beit-base'

# Option 1: load using AutoModel
data2vec_beit = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by BeitModel
data2vec_beit = BeitModel.from_pretrained(checkpoint)

Wav2Vec2

Data2Vec model trained with Wav2Vec2 as the encoder (data2vec-wav2vec2-base)

from transformers import AutoModel, AutoConfig
from transformers import Wav2Vec2Model

checkpoint = 'arxyzan/data2vec-wav2vec2-base'

# Option 1: load using AutoModel
data2vec_wav2vec2 = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by Wav2Vec2Model
data2vec_wav2vec2 = Wav2Vec2Model.from_pretrained(checkpoint)

Note: The above models' weights were carefully ported from the original checkpoints in the fairseq version.

Fine-tuning

  1. Fine-tune using the checkpoints mentioned above:
# Text classification using Roberta model from HuggingFace
from transformers import RobertaModel, RobertaForSequenceClassification

checkpoint = 'arxyzan/data2vec-roberta-base'
# this is exactly a roberta model but trained with data2vec
data2vec_roberta = RobertaModel.from_pretrained(checkpoint)
text_classifier = RobertaForSequenceClassification(data2vec_roberta.config)
# assign `data2vec-roberta` weights to the roberta block of the classifier
text_classifier.roberta = data2vec_roberta
...
  1. In case you trained a model using this codebase, you can fine-tune it by taking out the encoder's state dict from the checkpoint which gives you a HuggingFace model and you can fine-tune it for any downstream task as you'd normally do for HuggingFace models.
# load a checkpoint for finetuning
from transformers import RobertaModel, RobertaConfig
roberta = RobertaModel(RobertaConfig())
checkpoint = torch.load('path/to/data2vec.pt')
roberta_state_dict = checkpoint['encoder']
# load roberta weights from the encoder part of the data2vec model
encoder = roberta.load_state_dict(roberta_state_dict)

# Now fine-tune a regular HuggingFace RoBERTa model
...

Contributions

Any contribution regarding training, development and issues are welcome!

Owner
Aryan Shekarlaban
Deep Learning Developer & Researcher
Aryan Shekarlaban
Spert NLP Relation Extraction API deployed with torchserve for inference

URLMask Python program for Linux users to change a URL to ANY domain. A program than can take any url and mask it to any domain name you like. E.g. ne

Zichu Chen 1 Nov 24, 2021
無料で使える中品質なテキスト読み上げソフトウェア、VOICEVOXの音声合成エンジン

VOICEVOX ENGINE VOICEVOXの音声合成エンジン。 実態は HTTP サーバーなので、リクエストを送信すればテキスト音声合成できます。 API ドキュメント VOICEVOX ソフトウェアを起動した状態で、ブラウザから

Hiroshiba 3 Jul 05, 2022
Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language mod

20.5k Jan 08, 2023
FastFormers - highly efficient transformer models for NLU

FastFormers FastFormers provides a set of recipes and methods to achieve highly efficient inference of Transformer models for Natural Language Underst

Microsoft 678 Jan 05, 2023
A fast Text-to-Speech (TTS) model. Work well for English, Mandarin/Chinese, Japanese, Korean, Russian and Tibetan (so far). 快速语音合成模型,适用于英语、普通话/中文、日语、韩语、俄语和藏语(当前已测试)。

简体中文 | English 并行语音合成 [TOC] 新进展 2021/04/20 合并 wavegan 分支到 main 主分支,删除 wavegan 分支! 2021/04/13 创建 encoder 分支用于开发语音风格迁移模块! 2021/04/13 softdtw 分支 支持使用 Sof

Atomicoo 161 Dec 19, 2022
The PyTorch based implementation of continuous integrate-and-fire (CIF) module.

CIF-PyTorch This is a PyTorch based implementation of continuous integrate-and-fire (CIF) module for end-to-end (E2E) automatic speech recognition (AS

Minglun Han 24 Dec 29, 2022
All the code I wrote for Overwatch-related projects that I still own the rights to.

overwatch_shit.zip This is (eventually) going to contain all the software I wrote during my five-year imprisonment stay playing Overwatch. I'll be add

zkxjzmswkwl 2 Dec 31, 2021
CoNLL-English NER Task (NER in English)

CoNLL-English NER Task en | ch Motivation Course Project review the pytorch framework and sequence-labeling task practice using the transformers of Hu

Kevin 2 Jan 14, 2022
Original implementation of the pooling method introduced in "Speaker embeddings by modeling channel-wise correlations"

Speaker-Embeddings-Correlation-Pooling This is the original implementation of the pooling method introduced in "Speaker embeddings by modeling channel

Themos Stafylakis 10 Apr 30, 2022
A fast and lightweight python-based CTC beam search decoder for speech recognition.

pyctcdecode A fast and feature-rich CTC beam search decoder for speech recognition written in Python, providing n-gram (kenlm) language model support

Kensho 315 Dec 21, 2022
Command Line Text-To-Speech using Google TTS

cli-tts Thanks to gTTS by @pndurette! This is an interactive command line text-to-speech tool using Google TTS. Just type text and the voice will be p

ReekyStive 3 Nov 11, 2022
A framework for implementing federated learning

This is partly the reproduction of the paper of [Privacy-Preserving Federated Learning in Fog Computing](DOI: 10.1109/JIOT.2020.2987958. 2020)

DavidChen 46 Sep 23, 2022
MASS: Masked Sequence to Sequence Pre-training for Language Generation

MASS: Masked Sequence to Sequence Pre-training for Language Generation

Microsoft 1.1k Dec 17, 2022
FactSumm: Factual Consistency Scorer for Abstractive Summarization

FactSumm: Factual Consistency Scorer for Abstractive Summarization FactSumm is a toolkit that scores Factualy Consistency for Abstract Summarization W

devfon 83 Jan 09, 2023
Arabic speech recognition, classification and text-to-speech.

klaam Arabic speech recognition, classification and text-to-speech using many advanced models like wave2vec and fastspeech2. This repository allows tr

ARBML 177 Dec 27, 2022
Using BERT-based models for toxic span detection

SemEval 2021 Task 5: Toxic Spans Detection: Task: Link to SemEval-2021: Task 5 Toxic Span Detection is https://competitions.codalab.org/competitions/2

Ravika Nagpal 1 Jan 04, 2022
Just a Basic like Language for Zeno INC

zeno-basic-language Just a Basic like Language for Zeno INC This is written in 100% python. this is basic language like language. so its not for big p

Voidy Devleoper 1 Dec 18, 2021
Code associated with the Don't Stop Pretraining ACL 2020 paper

dont-stop-pretraining Code associated with the Don't Stop Pretraining ACL 2020 paper Citation @inproceedings{dontstoppretraining2020, author = {Suchi

AI2 449 Jan 04, 2023
Code for the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer"

T5: Text-To-Text Transfer Transformer The t5 library serves primarily as code for reproducing the experiments in Exploring the Limits of Transfer Lear

Google Research 4.6k Jan 01, 2023
中文无监督SimCSE Pytorch实现

A PyTorch implementation of unsupervised SimCSE SimCSE: Simple Contrastive Learning of Sentence Embeddings 1. 用法 无监督训练 python train_unsup.py ./data/ne

99 Dec 23, 2022