Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch

Overview

COCO LM Pretraining (wip)

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to Electra.

Install

$ pip install coco-lm-pytorch

Usage

An example using the x-transformers library

$ pip install x-transformers

Then

import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,         # smaller hidden dimension
        heads = 4,         # less heads
        ff_mult = 2,       # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator, whose output would be used for predicting token is still the same or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to ignore for mask modeling ex. (cls, sep)
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')

Citations

@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining}, 
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
You might also like...
Big Bird: Transformers for Longer Sequences

BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. Moreover, BigBird comes along with a theoretical understanding of the capabilities of a complete transformer that the sparse model can handle.

Beyond Paragraphs: NLP for Long Sequences

Beyond Paragraphs: NLP for Long Sequences

Text-Summarization-using-NLP - Text Summarization using NLP  to fetch BBC News Article and summarize its text and also it includes custom article Summarization PyTorch implementation of Microsoft's text-to-speech system FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.
PyTorch implementation of Microsoft's text-to-speech system FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.

An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"

MILES is a multilingual text simplifier inspired by LSBert - A BERT-based lexical simplification approach proposed in 2018. Unlike LSBert, MILES uses the bert-base-multilingual-uncased model, as well as simple language-agnostic approaches to complex word identification (CWI) and candidate ranking. PyTorch implementation of the paper:  Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding
PyTorch implementation of the paper: Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding

Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding This repository contains the official PyTorch implementation of th

Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation.  This is part of the CASL project: http://casl-project.ai/
Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation. This is part of the CASL project: http://casl-project.ai/

Texar-PyTorch is a toolkit aiming to support a broad set of machine learning, especially natural language processing and text generation tasks. Texar

In this repository, I have developed an end to end Automatic speech recognition project. I have developed the neural network model for automatic speech recognition with PyTorch and used MLflow to manage the ML lifecycle, including experimentation, reproducibility, deployment, and a central model registry.
Official PyTorch code for ClipBERT, an efficient framework for end-to-end learning on image-text and video-text tasks

Official PyTorch code for ClipBERT, an efficient framework for end-to-end learning on image-text and video-text tasks. It takes raw videos/images + text as inputs, and outputs task predictions. ClipBERT is designed based on 2D CNNs and transformers, and uses a sparse sampling strategy to enable efficient end-to-end video-and-language learning.

Comments
  • Question about corrective LM loss

    Question about corrective LM loss

    Hi @lucidrains ,

    Thanks for your great repo!

    I looked at your code: coco_lm_pytorch.py. I see there are three losses in line 242. weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss

    cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the loss of binary discrimination. I wonder where the LM loss of corrective language modeling loss is. Could you point me?

    Best, Abdul.

    opened by elmadany 0
  • What can v0.0.2 do?

    What can v0.0.2 do?

    I'm quite excited to give COCO-LM a try! Thanks as always for the great speedy open source repo @lucidrains .

    Quick question: has this repository been tried on real data, and if so - loosely what type of setup? Trying to figure out whether jumping in coco-lm-pytorch I should have the expectation of being a first beta-tester, or I'm looking at something that is already stable. Thanks!

    opened by dginev 0
Releases(0.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
💥 Fast State-of-the-Art Tokenizers optimized for Research and Production

Provides an implementation of today's most used tokenizers, with a focus on performance and versatility. Main features: Train new vocabularies and tok

Hugging Face 6.2k Dec 31, 2022
ConferencingSpeech2022; Non-intrusive Objective Speech Quality Assessment (NISQA) Challenge

ConferencingSpeech 2022 challenge This repository contains the datasets list and scripts required for the ConferencingSpeech 2022 challenge. For more

21 Dec 02, 2022
Host your own GPT-3 Discord bot

GPT3 Discord Bot Host your own GPT-3 Discord bot i'd host and make the bot invitable myself, however GPT3 terms of service prohibit public use of GPT3

[something hillarious here] 8 Jan 07, 2023
Stanford CoreNLP provides a set of natural language analysis tools written in Java

Stanford CoreNLP Stanford CoreNLP provides a set of natural language analysis tools written in Java. It can take raw human language text input and giv

Stanford NLP 8.8k Jan 07, 2023
What are the best Systems? New Perspectives on NLP Benchmarking

What are the best Systems? New Perspectives on NLP Benchmarking In Machine Learning, a benchmark refers to an ensemble of datasets associated with one

Pierre Colombo 12 Nov 03, 2022
ASCEND Chinese-English code-switching dataset

ASCEND (A Spontaneous Chinese-English Dataset) introduces a high-quality resource of spontaneous multi-turn conversational dialogue Chinese-English code-switching corpus collected in Hong Kong.

CAiRE 11 Dec 09, 2022
Flexible interface for high-performance research using SOTA Transformers leveraging Pytorch Lightning, Transformers, and Hydra.

Flexible interface for high performance research using SOTA Transformers leveraging Pytorch Lightning, Transformers, and Hydra. What is Lightning Tran

Pytorch Lightning 581 Dec 21, 2022
This repository contains the code, models and datasets discussed in our paper "Few-Shot Question Answering by Pretraining Span Selection"

Splinter This repository contains the code, models and datasets discussed in our paper "Few-Shot Question Answering by Pretraining Span Selection", to

Ori Ram 88 Dec 31, 2022
Chinese Pre-Trained Language Models (CPM-LM) Version-I

CPM-Generate 为了促进中文自然语言处理研究的发展,本项目提供了 CPM-LM (2.6B) 模型的文本生成代码,可用于文本生成的本地测试,并以此为基础进一步研究零次学习/少次学习等场景。[项目首页] [模型下载] [技术报告] 若您想使用CPM-1进行推理,我们建议使用高效推理工具BMI

Tsinghua AI 1.4k Jan 03, 2023
pytorch implementation of Attention is all you need

A Pytorch Implementation of the Transformer: Attention Is All You Need Our implementation is largely based on Tensorflow implementation Requirements N

230 Dec 07, 2022
Code for paper "Which Training Methods for GANs do actually Converge? (ICML 2018)"

GAN stability This repository contains the experiments in the supplementary material for the paper Which Training Methods for GANs do actually Converg

Lars Mescheder 884 Nov 11, 2022
HuggingSound: A toolkit for speech-related tasks based on HuggingFace's tools

HuggingSound HuggingSound: A toolkit for speech-related tasks based on HuggingFace's tools. I have no intention of building a very complex tool here.

Jonatas Grosman 247 Dec 26, 2022
Simple Python script to scrape youtube channles of "Parity Technologies and Web3 Foundation" and translate them to well-known braille language or any language

Simple Python script to scrape youtube channles of "Parity Technologies and Web3 Foundation" and translate them to well-known braille language or any

Little Endian 1 Apr 28, 2022
LSTM model - IMDB review sentiment analysis

NLP - Movie review sentiment analysis The colab notebook contains the code for building a LSTM Recurrent Neural Network that gives 87-88% accuracy on

Sundeep Bhimireddy 1 Jan 29, 2022
Interpretable Models for NLP using PyTorch

This repo is deprecated. Please find the updated package here. https://github.com/EdGENetworks/anuvada Anuvada: Interpretable Models for NLP using PyT

Sandeep Tammu 19 Dec 17, 2022
A paper list for aspect based sentiment analysis.

Aspect-Based-Sentiment-Analysis A paper list for aspect based sentiment analysis. Survey [IEEE-TAC-20]: Issues and Challenges of Aspect-based Sentimen

jiangqn 419 Dec 20, 2022
Create a machine learning model which will predict if the mortgage will be approved or not based on 5 variables

Mortgage-Application-Analysis Create a machine learning model which will predict if the mortgage will be approved or not based on 5 variables: age, in

1 Jan 29, 2022
A combination of autoregressors and autoencoders using XLNet for sentiment analysis

A combination of autoregressors and autoencoders using XLNet for sentiment analysis Abstract In this paper sentiment analysis has been performed in or

James Zaridis 2 Nov 20, 2021
Tensorflow Implementation of A Generative Flow for Text-to-Speech via Monotonic Alignment Search

Tensorflow Implementation of A Generative Flow for Text-to-Speech via Monotonic Alignment Search

Ankur Dhuriya 10 Oct 13, 2022