# About this repository
This repo contains a PyTorch implementation of the ACL 2017 paper *Get To The Point: Summarization with Pointer-Generator Networks*. The code framework is based on TextBox.
## Environment

```
python >= 3.8.11
torch >= 1.6.0
```
Run `install.sh` to install the other requirements.
## Dataset

The processed dataset can be downloaded from Google Drive. Once downloaded, unzip the data files (train.src, train.tgt, ...) into `./data`.

Dataset overview: train: 287,113 examples; dev: 13,368 examples; test: 11,490 examples.
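Each split appears to be stored as a pair of parallel, line-aligned files. A minimal loader might look like the sketch below (the one-example-per-line layout and the function name are assumptions; the repo's own dataloader also handles tokenization and truncation):

```python
from pathlib import Path

def load_split(data_dir, split):
    # Each split is a pair of line-aligned files: <split>.src holds the
    # source articles, <split>.tgt the reference summaries.
    src = Path(data_dir, f'{split}.src').read_text(encoding='utf-8').splitlines()
    tgt = Path(data_dir, f'{split}.tgt').read_text(encoding='utf-8').splitlines()
    assert len(src) == len(tgt), 'source/target files must be aligned'
    return list(zip(src, tgt))
```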
## Parameters
```yaml
# overall settings
data_path: 'data/'
checkpoint_dir: 'saved/'
generated_text_dir: 'generated/'

# dataset settings
max_vocab_size: 50000
src_len: 400
tgt_len: 100

# model settings
decoding_strategy: 'beam_search'
beam_size: 4
is_attention: True
is_pgen: True
is_coverage: True
cov_loss_lambda: 1.0
```
Log files are located in `./log`; more details can be found in the YAML files.
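The `is_pgen` and `is_coverage` flags correspond to the two mechanisms from the paper. As a rough illustration (a sketch, not the repo's actual code; the function names and tensor shapes are assumptions), the pointer-generator mixture and the coverage loss weighted by `cov_loss_lambda` can be written as:

```python
import torch

def coverage_loss(attn, coverage, cov_loss_lambda=1.0):
    # Coverage penalty from See et al. (2017): at each decoder step,
    # sum over source positions of min(attention, accumulated coverage).
    # attn, coverage: (batch, src_len)
    return cov_loss_lambda * torch.sum(torch.min(attn, coverage), dim=-1)

def final_distribution(p_gen, vocab_dist, attn, src_ids, extended_size):
    # Pointer-generator mixture: p_gen weights the generator's vocabulary
    # distribution, (1 - p_gen) weights the copy distribution given by
    # attention over source tokens, scattered onto the extended vocabulary
    # (which includes source OOV words).
    # p_gen: (batch, 1); vocab_dist: (batch, vocab); attn, src_ids: (batch, src_len)
    final = torch.zeros(vocab_dist.size(0), extended_size)
    final[:, :vocab_dist.size(1)] = p_gen * vocab_dist
    final.scatter_add_(1, src_ids, (1.0 - p_gen) * attn)
    return final
```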
Note: Distributed Data Parallel (DDP) is not supported yet.
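With `decoding_strategy: 'beam_search'` and `beam_size: 4`, decoding keeps the four highest-scoring partial summaries at every step. A generic log-probability beam search can be sketched as follows (`step_fn` and the token interface are assumptions for illustration, not the repo's API):

```python
def beam_search(step_fn, start_token, beam_size=4, max_len=10, eos=None):
    # step_fn(seq) returns a list of (token, log_prob) continuations
    # for the given prefix; beams hold (sequence, cumulative log-prob).
    beams = [([start_token], 0.0)]
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if eos is not None and seq[-1] == eos:
                candidates.append((seq, score))  # finished beam carries over
                continue
            for tok, lp in step_fn(seq):
                candidates.append((seq + [tok], score + lp))
        # Keep only the beam_size highest-scoring hypotheses.
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
        if eos is not None and all(s[-1] == eos for s, _ in beams):
            break
    return beams[0][0]
```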
## Train & Evaluation

To train from scratch, run `fire.py`:
```python
if __name__ == '__main__':
    config = Config(config_dict={'test_only': False,
                                 'load_experiment': None})
    train(config)
```
To resume from a checkpoint, set `'load_experiment'` to `'./saved/$model_name$.pth'`. Similarly, when `'test_only'` is set to `True`, `'load_experiment'` is required.
## Results

The best model was trained on a single TITAN Xp GPU (about 8 GB of memory used).
### Training loss

### Ablation study
| Model | Rouge-1 | Rouge-2 | Rouge-L |
|---|---|---|---|
| Seq2Seq | 22.17 | 7.20 | 20.97 |
| Seq2Seq+attn | 29.35 | 12.58 | 27.38 |
| Seq2Seq+attn+pgen | 36.04 | 15.87 | 32.92 |
| Seq2Seq+attn+pgen+coverage | 39.52 | 17.85 | 36.40 |
Note: The Seq2Seq model is built on an LSTM; I hope to replace it with a Transformer in the future.
