使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法

Related tags

Text Data & NLPSimCSE
Overview

SimCSE复现

项目描述

SimCSE是一种简单但是很巧妙的NLP对比学习方法,创新性地引入Dropout的方式,对样本添加噪声,从而达到对正样本增强的目的。 该框架的训练目的为:对于batch中的每个样本,拉近其与正样本之间的距离,拉远其与负样本之间的距离,使得模型能够在大规模无监督语料(也可以使用有监督的语料)中学习到文本相似关系。 详见论文:Simple Contrastive Learning of Sentence EmbeddingsSimCSE官方代码仓库

本项目使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法,并且在STS-B数据集上进行消融实验,评价指标为Spearman相关系数,预训练模型为Bert-base-uncased, 验证了SimCSE的有效性。在STS-B数据集上,有监督训练和无监督训练的复现效果如下表。

在无监督训练中,dropout=0.1时,复现效果比原文略差,但也比较接近。当dropout=0.2时,复现效果比原文略高。 ** 但在有监督训练中,不知是否由于batch size过小(原论文使用512),复现效果与论文的效果相差较远,后续会进行排查。 **

训练方法 learning rate batch size dropout Spearman’s correlation
原论文 无监督 3e-5 64 0.1 0.763
复现 无监督 3e-5 64 0.2 0.771
复现 无监督 3e-5 64 0.1 0.748
原论文 有监督 5e-5 512 0.1 0.816
复现 有监督 5e-5 64 0.1 0.764

运行环境

python==3.6、transformers==3.1.0、torch==1.6.0

项目结构

  • data:存放训练数据
    • stsbenchmark:STS-B数据集
      • sts-dev.csv:STS-B验证集
      • sts-test.csv:STS-B验测试集
    • nli_for_simcse.csv:数量275601为的NLI数据集
    • wiki1m_for_simcse.txt:维基百科上获取的100w的文本
  • output:输出目录
  • pretrain_model:预训练模型存放位置
  • script:脚本存放位置。
  • dataset.py
  • model.py:模型代码,包含有监督和无监督损失函数的计算方式
  • train.py:训练代码

使用方法

Quick Start

下载训练数据:

bash script/download_nli.sh
bash script/download_wiki.sh

无监督训练,运行脚本

bash script/run_unsup_train.sh

有监督训练,运行脚本

bash script/run_sup_train.sh

实验

无监督训练

从前四条实验数据中可以看到,较大的batch size在一定程度上可以增加模型的泛化性。

dropout为0.2的时候,训练效果比0.1与0.3更好,有可能dropout=0.1加入的噪声过小,而dropout=0.3加入的噪声过大,增强得到的样本与原始样本差异较大。

learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
3e-5 256 0.1 6000 0.800 0.761
3e-5 128 0.1 4200 0.799 0.747
3e-5 64 0.1 10900 0.803 0.748
3e-5 32 0.1 21300 0.787 0.714
3e-5 64 0.2 11200 0.811 0.771
3e-5 64 0.3 6300 0.781 0.745
1e-5 64 0.1 16400 0.798 0.751

有监督训练

有监督实验的复现结果未达到预期,超参数相同时,在验证集上的得分略高于无监督,但是在测试集上,得分基本没有差异。增大有监督训练的学习率,有监督的训练的得分略高于无监督训练, 但还是与论文声称的0.816相差较远,原论文使用512的batch size, 不知是否由于batch size的设置有关,后续会对有监督的训练代码进一步排查。

不过从训练曲线可以看到,有监督训练的收敛速度明显快于无监督训练,这也符合我们的认知。

训练方法 learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
无监督 3e-5 64 0.1 10900 0.803 0.748
有监督 3e-5 64 0.1 200 0.810 0.748
有监督 5e-5 64 0.1 2300 0.809 0.764
有监督 3e-5 32 0.1 200 0.808 0.743
有监督 5e-5 32 0.1 200 0.806 0.746

无监督训练过程中,验证集得分的变化曲线: avatar

有监督训练过程中,验证集得分的变化曲线: avatar

REFERENCE

TODO

  • 排查有监督学习的效果不符合预期的原因
Machine learning classifiers to predict American Sign Language .

ASL-Classifiers American Sign Language (ASL) is a natural language that serves as the predominant sign language of Deaf communities in the United Stat

Tarek idrees 0 Feb 08, 2022
Words_And_Phrases - Just a repo for useful words and phrases that might come handy in some scenarios. Feel free to add yours

Words_And_Phrases Just a repo for useful words and phrases that might come handy in some scenarios. Feel free to add yours Abbreviations Abbreviation

Subhadeep Mandal 1 Feb 01, 2022
A python script that will use hydra to get user and password to login to ssh, ftp, and telnet

Hydra-Auto-Hack A python script that will use hydra to get user and password to login to ssh, ftp, and telnet Project Description This python script w

2 Jan 16, 2022
A Chinese to English Neural Model Translation Project

ZH-EN NMT Chinese to English Neural Machine Translation This project is inspired by Stanford's CS224N NMT Project Dataset used in this project: News C

Zhenbang Feng 29 Nov 26, 2022
End-to-end image captioning with EfficientNet-b3 + LSTM with Attention

Image captioning End-to-end image captioning with EfficientNet-b3 + LSTM with Attention Model is seq2seq model. In the encoder pretrained EfficientNet

2 Feb 10, 2022
Concept Modeling: Topic Modeling on Images and Text

Concept is a technique that leverages CLIP and BERTopic-based techniques to perform Concept Modeling on images.

Maarten Grootendorst 120 Dec 27, 2022
In this Notebook I've build some machine-learning and deep-learning to classify corona virus tweets, in both multi class classification and binary classification.

Hello, This Notebook Contains Example of Corona Virus Tweets Multi Class Classification. - Classes is: Extremely Positive, Positive, Extremely Negativ

Khaled Tofailieh 3 Dec 06, 2022
Pipelines de datos, 2021.

Este repo ilustra un proceso sencillo de automatización de transformación y modelado de datos, a través de un pipeline utilizando Luigi. Stack princip

Rodolfo Ferro 8 May 19, 2022
An Analysis Toolkit for Natural Language Generation (Translation, Captioning, Summarization, etc.)

VizSeq is a Python toolkit for visual analysis on text generation tasks like machine translation, summarization, image captioning, speech translation

Facebook Research 409 Oct 28, 2022
BERT-based Financial Question Answering System

BERT-based Financial Question Answering System In this example, we use Jina, PyTorch, and Hugging Face transformers to build a production-ready BERT-b

Bithiah Yuan 61 Sep 18, 2022
Code to reproduce the results of the paper 'Towards Realistic Few-Shot Relation Extraction' (EMNLP 2021)

Realistic Few-Shot Relation Extraction This repository contains code to reproduce the results in the paper "Towards Realistic Few-Shot Relation Extrac

Bloomberg 8 Nov 09, 2022
This repository is home to the Optimus data transformation plugins for various data processing needs.

Transformers Optimus's transformation plugins are implementations of Task and Hook interfaces that allows execution of arbitrary jobs in optimus. To i

Open Data Platform 37 Dec 14, 2022
L3Cube-MahaCorpus a Marathi monolingual data set scraped from different internet sources.

L3Cube-MahaCorpus L3Cube-MahaCorpus a Marathi monolingual data set scraped from different internet sources. We expand the existing Marathi monolingual

21 Dec 17, 2022
Code for "Generative adversarial networks for reconstructing natural images from brain activity".

Reconstruct handwritten characters from brains using GANs Example code for the paper "Generative adversarial networks for reconstructing natural image

K. Seeliger 2 May 17, 2022
Training code of Spatial Time Memory Network. Semi-supervised video object segmentation.

Training-code-of-STM This repository fully reproduces Space-Time Memory Networks Performance on Davis17 val set&Weights backbone training stage traini

haochen wang 128 Dec 11, 2022
Code from the paper "High-Performance Brain-to-Text Communication via Handwriting"

Code from the paper "High-Performance Brain-to-Text Communication via Handwriting"

Francis R. Willett 305 Dec 22, 2022
A flask application to predict the speech emotion of any .wav file.

This is a speech emotion recognition app. It will allow you to train a modular MLP model with the RAVDESS dataset, and then use that model with a flask application to predict the speech emotion of an

Aryan Vijaywargia 2 Dec 15, 2021
Modified GPT using average pooling to reduce the softmax attention memory constraints.

NLP-GPT-Upsampling This repository contains an implementation of Open AI's GPT Model. In particular, this implementation takes inspiration from the Ny

WD 1 Dec 03, 2021
fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.

fastNLP fastNLP是一款轻量级的自然语言处理(NLP)工具包,目标是快速实现NLP任务以及构建复杂模型。 fastNLP具有如下的特性: 统一的Tabular式数据容器,简化数据预处理过程; 内置多种数据集的Loader和Pipe,省去预处理代码; 各种方便的NLP工具,例如Embedd

fastNLP 2.8k Jan 01, 2023
FewCLUE: 为中文NLP定制的小样本学习测评基准

FewCLUE: 为中文NLP定制的小样本学习测评基准

CLUE benchmark 387 Jan 04, 2023