The source code for the ACL 2021 paper 'BoB: BERT Over BERT for Training Persona-based Dialogue Models from Limited Personalized Data'.

Overview


This repository provides the implementation of the ACL 2021 main conference paper:

BoB: BERT Over BERT for Training Persona-based Dialogue Models from Limited Personalized Data. [paper]

1. Data Preparation

In this work, we carried out persona-based dialogue generation experiments in a persona-dense scenario (English PersonaChat) and a persona-sparse scenario (Chinese PersonalDialog), with the help of a series of auxiliary natural language inference (NLI) datasets. Here we summarize the key information of these datasets and provide download links for those that are directly accessible.

2. How to Run

The setup.sh script lists the dependencies needed to run this project; simply running ./setup.sh will install them. Here we take the English PersonaChat dataset as an example to illustrate how to run the dialogue generation experiments. There are three steps: tokenization, training and inference:

  • Preprocessing

     python preprocess.py --dataset_type convai2 \
     --trainset ./data/ConvAI2/train_self_original_no_cands.txt \
     --testset ./data/ConvAI2/valid_self_original_no_cands.txt \
     --nliset ./data/ConvAI2/ \
     --encoder_model_name_or_path ./pretrained_models/bert/bert-base-uncased/ \
     --max_source_length 64 \
     --max_target_length 32
    

    We have provided some data examples (a few dozen lines) in the ./data directory to show the data format. preprocess.py reads the different datasets and tokenizes the raw data into sequences of vocabulary IDs for model training. --dataset_type can be either convai2 (for English PersonaChat) or ecdt2019 (for Chinese PersonalDialog). The tokenized data are saved as a series of JSON files.
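    To illustrate the raw ConvAI2 format before tokenization, here is a minimal sketch, not the repository's preprocessing code, that groups the lines of a file such as train_self_original_no_cands.txt into dialogues. It assumes the public ParlAI ConvAI2 layout: every line is '<index> <text>', the index restarts at 1 for each new dialogue, persona lines begin with 'your persona:', and dialogue lines hold a query and its response separated by a tab. The names Episode and parse_convai2 are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Tuple


@dataclass
class Episode:
    """One ConvAI2 dialogue: persona sentences plus (query, response) turns."""
    persona: List[str] = field(default_factory=list)
    turns: List[Tuple[str, str]] = field(default_factory=list)


def parse_convai2(lines: Iterable[str]) -> List[Episode]:
    """Group ConvAI2 'self_original_no_cands' lines into episodes.

    Assumed layout (ParlAI ConvAI2 release, not confirmed by this repo):
    '<index> <text>', index 1 starts a new dialogue, persona lines start
    with 'your persona:', dialogue lines are 'query<TAB>response'.
    """
    episodes: List[Episode] = []
    for raw in lines:
        raw = raw.rstrip("\n")
        if not raw.strip():
            continue  # skip blank lines
        index, _, text = raw.partition(" ")
        if index == "1" or not episodes:
            episodes.append(Episode())  # index 1 marks a new dialogue
        current = episodes[-1]
        if text.startswith("your persona:"):
            current.persona.append(text[len("your persona:"):].strip())
        else:
            query, _, response = text.partition("\t")
            current.turns.append((query.strip(), response.strip()))
    return episodes
```

    The persona, query and response fields presumably correspond to the persona, query and gold fields in the evaluation output further down. Converting the text into vocabulary IDs would then be the job of the BERT tokenizer from 🤗 Transformers loaded from the --encoder_model_name_or_path directory; how preprocess.py applies --max_source_length and --max_target_length is not shown here.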

  • Model Training

     CUDA_VISIBLE_DEVICES=0 python bertoverbert.py --do_train \
     --encoder_model ./pretrained_models/bert/bert-base-uncased/ \
     --decoder_model ./pretrained_models/bert/bert-base-uncased/ \
     --decoder2_model ./pretrained_models/bert/bert-base-uncased/ \
     --save_model_path checkpoints/ConvAI2/bertoverbert --dataset_type convai2 \
     --dumped_token ./data/ConvAI2/convai2_tokenized/ \
     --learning_rate 7e-6 \
     --batch_size 32
    

    Here we initialize the encoder and both decoders from the same downloaded BERT checkpoint. More parameter settings can be found in bertoverbert.py.
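    If you prefer to launch training from Python (for example, to script several runs), the training command above can be assembled as an argument list. This is a hypothetical convenience wrapper (build_train_cmd and run_train are not part of the repository); it only reproduces the flags shown above and passes the same BERT directory to --encoder_model, --decoder_model and --decoder2_model, mirroring the shared initialization.

```python
import os
import shlex
import subprocess
from typing import List

# Same checkpoint directory used for the encoder and both decoders.
BERT_PATH = "./pretrained_models/bert/bert-base-uncased/"


def build_train_cmd(bert_path: str = BERT_PATH,
                    learning_rate: str = "7e-6",
                    batch_size: int = 32) -> List[str]:
    """Return the argv of the PersonaChat training command shown above."""
    return [
        "python", "bertoverbert.py", "--do_train",
        "--encoder_model", bert_path,
        "--decoder_model", bert_path,
        "--decoder2_model", bert_path,
        "--save_model_path", "checkpoints/ConvAI2/bertoverbert",
        "--dataset_type", "convai2",
        "--dumped_token", "./data/ConvAI2/convai2_tokenized/",
        "--learning_rate", learning_rate,
        "--batch_size", str(batch_size),
    ]


def run_train(gpu: str = "0", **kwargs) -> None:
    """Launch training on one GPU, like prefixing CUDA_VISIBLE_DEVICES=<gpu>."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    subprocess.run(build_train_cmd(**kwargs), env=env, check=True)


def as_shell(cmd: List[str]) -> str:
    """Render the argument list as a copy-pasteable shell command."""
    return shlex.join(cmd)
```

    Calling run_train(gpu="0") is then equivalent to the shell command above; any other flag accepted by bertoverbert.py would have to be appended to the list by hand.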

  • Evaluations

     CUDA_VISIBLE_DEVICES=0 python bertoverbert.py --dumped_token ./data/ConvAI2/convai2_tokenized/ \
     --dataset_type convai2 \
     --encoder_model ./pretrained_models/bert/bert-base-uncased/  \
     --do_evaluation --do_predict \
     --eval_epoch 7
    

    Empirically, in the PersonaChat experiment with the default hyperparameter settings, the best-performing checkpoint typically lies between epoch 5 and epoch 9. If training proceeds normally, you should see results like:

     Perplexity on test set is 21.037 and 7.813.
    

    where 21.037 is the perplexity (ppl) of the first decoder and 7.813 is the final ppl of the second decoder. The generated results are written to test_result.tsv; here is an example generated by the above checkpoint:

     persona:i'm terrified of scorpions. i am employed by the us postal service. i've a german shepherd named barnaby. my father drove a car for nascar.
     query:sorry to hear that. my dad is an army soldier.
     gold:i thank him for his service.
     response_from_d1:that's cool. i'm a train driver.
     response_from_d2:that's cool. i'm a bit of a canadian who works for america.  
    

    where d1 and d2 denote the first and second BERT decoders, respectively.
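    Since the best checkpoint usually falls between epoch 5 and epoch 9, one way to choose it is to run the evaluation command once per --eval_epoch value and compare the reported perplexities. The sketch below assumes the summary line has exactly the form of the sample output above; parse_ppl, best_epoch and perplexity are hypothetical helpers. perplexity() shows the standard definition (the exponential of the mean per-token negative log-likelihood, in nats), which is not necessarily how bertoverbert.py computes its numbers.

```python
import math
import re
from typing import Dict, Iterable, Optional, Tuple

# Assumed summary format, e.g. "Perplexity on test set is 21.037 and 7.813."
PPL_LINE = re.compile(
    r"Perplexity on test set is (\d+(?:\.\d+)?) and (\d+(?:\.\d+)?)")


def perplexity(token_nlls: Iterable[float]) -> float:
    """Standard perplexity: exp of the mean per-token negative log-likelihood."""
    nlls = list(token_nlls)
    if not nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(nlls) / len(nlls))


def parse_ppl(log_text: str) -> Optional[Tuple[float, float]]:
    """Return (first-decoder ppl, second-decoder ppl), or None if absent."""
    match = PPL_LINE.search(log_text)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))


def best_epoch(logs_by_epoch: Dict[int, str]) -> int:
    """Pick the epoch whose final (second-decoder) perplexity is lowest."""
    scores = {}
    for epoch, text in logs_by_epoch.items():
        parsed = parse_ppl(text)
        if parsed is not None:
            scores[epoch] = parsed[1]
    if not scores:
        raise ValueError("no perplexity summary found in any log")
    return min(scores, key=scores.get)
```

    For example, collect the printed output of each evaluation run (e.g. via subprocess.run(..., capture_output=True, text=True)) into a dict keyed by epoch and pass it to best_epoch. Note that the --testset used above is valid_self_original_no_cands.txt, so choosing an epoch this way tunes on that split.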

  • Computing Infrastructure:

    • The released code was tested on NVIDIA Tesla V100 32G and NVIDIA PCIe A100 40G GPUs. Note that with batch_size=32, the BoB model needs at least 20 GB of GPU memory for training.
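    Before launching training, it can help to check which GPU has enough free memory for the 20 GB requirement. A minimal sketch, assuming nvidia-smi is on the PATH: it parses the output of nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits (one 'index, MiB' line per GPU) and treats 20 GB as 20 GiB, which is slightly conservative. The helper names are hypothetical.

```python
import subprocess
from typing import List, Tuple

# With these flags nvidia-smi prints one "index, free_MiB" line per GPU.
QUERY = ["nvidia-smi", "--query-gpu=index,memory.free",
         "--format=csv,noheader,nounits"]


def query_gpus() -> str:
    """Run nvidia-smi and return its raw CSV output (needs an NVIDIA driver)."""
    return subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout


def parse_gpu_memory(output: str) -> List[Tuple[int, int]]:
    """Parse the CSV output into (gpu_index, free_MiB) pairs."""
    gpus = []
    for line in output.strip().splitlines():
        index, free_mib = (part.strip() for part in line.split(","))
        gpus.append((int(index), int(free_mib)))
    return gpus


def gpus_with_at_least(output: str, min_gib: float = 20.0) -> List[int]:
    """Indices of GPUs with at least `min_gib` GiB of free memory."""
    return [idx for idx, mib in parse_gpu_memory(output) if mib >= min_gib * 1024]


# Usage on a GPU machine: print(gpus_with_at_least(query_gpus()))
```

    The returned indices can be used as the CUDA_VISIBLE_DEVICES value in the commands above.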

MISC

  • Built upon 🤗 Transformers.

  • Bibtex:

      @inproceedings{song-etal-2021-bob,
          title = "BoB: BERT Over BERT for Training Persona-based Dialogue Models from Limited Personalized Data",
          author = "Haoyu Song and Yan Wang and Kaiyan Zhang and Wei-Nan Zhang and Ting Liu",
          booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics (ACL-2021)",
          month = "Aug",
          year = "2021",
          address = "Online",
          publisher = "Association for Computational Linguistics",
      }
      
  • Email: [email protected].
