使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法

Related tags

Text Data & NLPSimCSE
Overview

SimCSE复现

项目描述

SimCSE是一种简单但是很巧妙的NLP对比学习方法,创新性地引入Dropout的方式,对样本添加噪声,从而达到对正样本增强的目的。 该框架的训练目的为:对于batch中的每个样本,拉近其与正样本之间的距离,拉远其与负样本之间的距离,使得模型能够在大规模无监督语料(也可以使用有监督的语料)中学习到文本相似关系。 详见论文:Simple Contrastive Learning of Sentence EmbeddingsSimCSE官方代码仓库

本项目使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法,并且在STS-B数据集上进行消融实验,评价指标为Spearman相关系数,预训练模型为Bert-base-uncased, 验证了SimCSE的有效性。在STS-B数据集上,有监督训练和无监督训练的复现效果如下表。

在无监督训练中,dropout=0.1时,复现效果比原文略差,但也比较接近。当dropout=0.2时,复现效果比原文略高。 ** 但在有监督训练中,不知是否由于batch size过小(原论文使用512),复现效果与论文的效果相差较远,后续会进行排查。 **

训练方法 learning rate batch size dropout Spearman’s correlation
原论文 无监督 3e-5 64 0.1 0.763
复现 无监督 3e-5 64 0.2 0.771
复现 无监督 3e-5 64 0.1 0.748
原论文 有监督 5e-5 512 0.1 0.816
复现 有监督 5e-5 64 0.1 0.764

运行环境

python==3.6、transformers==3.1.0、torch==1.6.0

项目结构

  • data:存放训练数据
    • stsbenchmark:STS-B数据集
      • sts-dev.csv:STS-B验证集
      • sts-test.csv:STS-B验测试集
    • nli_for_simcse.csv:数量275601为的NLI数据集
    • wiki1m_for_simcse.txt:维基百科上获取的100w的文本
  • output:输出目录
  • pretrain_model:预训练模型存放位置
  • script:脚本存放位置。
  • dataset.py
  • model.py:模型代码,包含有监督和无监督损失函数的计算方式
  • train.py:训练代码

使用方法

Quick Start

下载训练数据:

bash script/download_nli.sh
bash script/download_wiki.sh

无监督训练,运行脚本

bash script/run_unsup_train.sh

有监督训练,运行脚本

bash script/run_sup_train.sh

实验

无监督训练

从前四条实验数据中可以看到,较大的batch size在一定程度上可以增加模型的泛化性。

dropout为0.2的时候,训练效果比0.1与0.3更好,有可能dropout=0.1加入的噪声过小,而dropout=0.3加入的噪声过大,增强得到的样本与原始样本差异较大。

learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
3e-5 256 0.1 6000 0.800 0.761
3e-5 128 0.1 4200 0.799 0.747
3e-5 64 0.1 10900 0.803 0.748
3e-5 32 0.1 21300 0.787 0.714
3e-5 64 0.2 11200 0.811 0.771
3e-5 64 0.3 6300 0.781 0.745
1e-5 64 0.1 16400 0.798 0.751

有监督训练

有监督实验的复现结果未达到预期,超参数相同时,在验证集上的得分略高于无监督,但是在测试集上,得分基本没有差异。增大有监督训练的学习率,有监督的训练的得分略高于无监督训练, 但还是与论文声称的0.816相差较远,原论文使用512的batch size, 不知是否由于batch size的设置有关,后续会对有监督的训练代码进一步排查。

不过从训练曲线可以看到,有监督训练的收敛速度明显快于无监督训练,这也符合我们的认知。

训练方法 learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
无监督 3e-5 64 0.1 10900 0.803 0.748
有监督 3e-5 64 0.1 200 0.810 0.748
有监督 5e-5 64 0.1 2300 0.809 0.764
有监督 3e-5 32 0.1 200 0.808 0.743
有监督 5e-5 32 0.1 200 0.806 0.746

无监督训练过程中,验证集得分的变化曲线: avatar

有监督训练过程中,验证集得分的变化曲线: avatar

REFERENCE

TODO

  • 排查有监督学习的效果不符合预期的原因
Anomaly Detection 이상치 탐지 전처리 모듈

Anomaly Detection 시계열 데이터에 대한 이상치 탐지 1. Kernel Density Estimation을 활용한 이상치 탐지 train_data_path와 test_data_path에 존재하는 시점 정보를 포함하고 있는 csv 형태의 train data와

CLUST-consortium 43 Nov 28, 2022
Random Directed Acyclic Graph Generator

DAG_Generator Random Directed Acyclic Graph Generator verison1.0 简介 工作流通常由DAG(有向无环图)来定义,其中每个计算任务$T_i$由一个顶点(node,task,vertex)表示。同时,任务之间的每个数据或控制依赖性由一条加权

Livion 17 Dec 27, 2022
skweak: A software toolkit for weak supervision applied to NLP tasks

Labelled data remains a scarce resource in many practical NLP scenarios. This is especially the case when working with resource-poor languages (or text domains), or when using task-specific labels wi

Norsk Regnesentral (Norwegian Computing Center) 850 Dec 28, 2022
Honor's thesis project analyzing whether the GPT-2 model can more effectively generate free-verse or structured poetry.

gpt2-poetry The following code is for my senior honor's thesis project, under the guidance of Dr. Keith Holyoak at the University of California, Los A

Ashley Kim 2 Jan 09, 2022
Code for the ACL 2021 paper "Structural Guidance for Transformer Language Models"

Structural Guidance for Transformer Language Models This repository accompanies the paper, Structural Guidance for Transformer Language Models, publis

International Business Machines 10 Dec 14, 2022
Saptak Bhoumik 14 May 24, 2022
2021 AI CUP Competition on Traditional Chinese Scene Text Recognition - Intermediate Contest

繁體中文場景文字辨識 程式碼說明 組別:這就是我 成員:蔣明憲 唐碩謙 黃玥菱 林冠霆 蕭靖騰 目錄 環境套件 安裝方式 資料夾布局 前處理-製作偵測訓練註解檔 前處理-製作分類訓練樣本 part.py : 從 json 裁切出分類訓練樣本 Class.py : 將切出來的樣本按照文字分類到各資料夾

HuanyueTW 3 Jan 14, 2022
A Word Level Transformer layer based on PyTorch and 🤗 Transformers.

Transformer Embedder A Word Level Transformer layer based on PyTorch and 🤗 Transformers. How to use Install the library from PyPI: pip install transf

Riccardo Orlando 27 Nov 20, 2022
An example project using OpenPrompt under pytorch-lightning for prompt-based SST2 sentiment analysis model

pl_prompt_sst An example project using OpenPrompt under the framework of pytorch-lightning for a training prompt-based text classification model on SS

Zhiling Zhang 5 Oct 21, 2022
Text classification on IMDB dataset using Keras and Bi-LSTM network

Text classification on IMDB dataset using Keras and Bi-LSTM Text classification on IMDB dataset using Keras and Bi-LSTM network. Usage python3 main.py

Hamza Rashid 2 Sep 27, 2022
Visual Automata is a Python 3 library built as a wrapper for Caleb Evans' Automata library to add more visualization features.

Visual Automata Copyright 2021 Lewi Lie Uberg Released under the MIT license Visual Automata is a Python 3 library built as a wrapper for Caleb Evans'

Lewi Uberg 55 Nov 17, 2022
ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

ELECTRA Introduction ELECTRA is a method for self-supervised language representation learning. It can be used to pre-train transformer networks using

Google Research 2.1k Dec 28, 2022
This repository contains the code for "Exploiting Cloze Questions for Few-Shot Text Classification and Natural Language Inference"

Pattern-Exploiting Training (PET) This repository contains the code for Exploiting Cloze Questions for Few-Shot Text Classification and Natural Langua

Timo Schick 1.4k Dec 30, 2022
End-to-end image captioning with EfficientNet-b3 + LSTM with Attention

Image captioning End-to-end image captioning with EfficientNet-b3 + LSTM with Attention Model is seq2seq model. In the encoder pretrained EfficientNet

2 Feb 10, 2022
A look-ahead multi-entity Transformer for modeling coordinated agents.

baller2vec++ This is the repository for the paper: Michael A. Alcorn and Anh Nguyen. baller2vec++: A Look-Ahead Multi-Entity Transformer For Modeling

Michael A. Alcorn 30 Dec 16, 2022
Unet-TTS: Improving Unseen Speaker and Style Transfer in One-shot Voice Cloning

Unet-TTS: Improving Unseen Speaker and Style Transfer in One-shot Voice Cloning English | 中文 ❗ Now we provide inferencing code and pre-training models

164 Jan 02, 2023
Transformer-based Text Auto-encoder (T-TA) using TensorFlow 2.

T-TA (Transformer-based Text Auto-encoder) This repository contains codes for Transformer-based Text Auto-encoder (T-TA, paper: Fast and Accurate Deep

Jeong Ukjae 13 Dec 13, 2022
A python gui program to generate reddit text to speech videos from the id of any post.

Reddit text to speech generator A python gui program to generate reddit text to speech videos from the id of any post. Current functionality Generate

Aadvik 17 Dec 19, 2022
Named-entity recognition using neural networks. Easy-to-use and state-of-the-art results.

NeuroNER NeuroNER is a program that performs named-entity recognition (NER). Website: neuroner.com. This page gives step-by-step instructions to insta

Franck Dernoncourt 1.6k Dec 27, 2022
Code for CodeT5: a new code-aware pre-trained encoder-decoder model.

CodeT5: Identifier-aware Unified Pre-trained Encoder-Decoder Models for Code Understanding and Generation This is the official PyTorch implementation

Salesforce 564 Jan 08, 2023