COCO LM Pretraining (wip)

Overview

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to Electra.
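
For intuition, the sequence-level contrastive objective is an InfoNCE-style loss over [CLS] embeddings of two views of each sequence (the corrupted sequence and a cropped original): matching pairs are pulled together, and the other sequences in the batch act as negatives. The sketch below is illustrative only and is not the exact implementation in this repository; the function name and temperature are made up for the example.

import torch
import torch.nn.functional as F

def sequence_contrastive_loss(cls_a, cls_b, temperature = 0.07):
    # cls_a, cls_b: (batch, dim) [CLS] embeddings of two views of the same batch of sequences
    cls_a, cls_b = F.normalize(cls_a, dim = -1), F.normalize(cls_b, dim = -1)
    sims = cls_a @ cls_b.t() / temperature                        # cosine similarities of all pairs
    labels = torch.arange(cls_a.shape[0], device = cls_a.device)
    return F.cross_entropy(sims, labels)                          # the aligned pair is the positive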

Install

$ pip install coco-lm-pytorch

Usage

An example using the x-transformers library

$ pip install x-transformers

Then

import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure the generator is roughly a quarter to half the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,         # smaller hidden dimension
        heads = 4,         # fewer heads
        ff_mult = 2,       # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator whose output is used to predict whether each token is original or has been replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to exclude from masking, ex. [cls], [sep]
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)
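
# note (assumption based on this repo's source, coco_lm_pytorch.py): the three objectives above
# are combined as
#   loss = cl_weight * contrastive_loss + gen_weight * mlm_loss + disc_weight * discriminator_loss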

# (4) train

data = torch.randint(0, 20000, (1, 1024))

loss = trainer(data)
loss.backward()
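
# illustrative only (not part of the original example): in practice you would loop over real
# batches and step an optimizer over both networks; COCO is assumed here to be an nn.Module
# that exposes the parameters of the generator and discriminator

opt = torch.optim.AdamW(trainer.parameters(), lr = 5e-4)

for _ in range(10):              # toy loop reusing the same random batch
    loss = trainer(data)
    opt.zero_grad()
    loss.backward()
    opt.step()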

# after much training, the discriminator should have improved

torch.save(discriminator, './pretrained-model.pt')
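
Once saved, the discriminator can be used as an ordinary encoder for downstream tasks. Below is a minimal, illustrative sketch (not part of the official API) of extracting a [CLS] sequence embedding for fine-tuning, assuming x-transformers' return_embeddings keyword; the classification head and task are made up for the example.

import torch
import torch.nn as nn

discriminator = torch.load('./pretrained-model.pt')

# prepend the reserved [CLS] token (id 1 in the trainer config above)
seq = torch.randint(3, 20000, (1, 1023))
cls = torch.full((1, 1), 1, dtype = torch.long)
ids = torch.cat((cls, seq), dim = 1)

embeds = discriminator(ids, return_embeddings = True)    # (1, 1024, 1024) hidden states
cls_embed = embeds[:, 0]                                  # sequence-level representation

classifier = nn.Linear(1024, 2)                           # hypothetical binary task head
logits = classifier(cls_embed)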

Citations

@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining}, 
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • Question about corrective LM loss

    Hi @lucidrains ,

    Thanks for your great repo!

    I looked at your code in coco_lm_pytorch.py and see there are three losses at line 242: weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss

    cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the loss of binary discrimination. I wonder where the corrective language modeling loss is. Could you point me to it?

    Best, Abdul.

    opened by elmadany 0
  • What can v0.0.2 do?

    I'm quite excited to give COCO-LM a try! Thanks as always for the great speedy open source repo @lucidrains .

    Quick question: has this repository been tried on real data, and if so, loosely what type of setup? Trying to figure out whether, jumping into coco-lm-pytorch, I should expect to be one of the first beta testers, or whether I'm looking at something that is already stable. Thanks!

    opened by dginev 0