Learned Token Pruning for Transformers

Overview

LTP: Learned Token Pruning for Transformers

Check our paper for more details.

Installation

We follow the same installation procedure as the original Huggingface transformers repo.

pip install scikit-learn scipy datasets torch
pip install -e .  # in the top directory

Prepare Checkpoints

LTP is implemented on top of Huggingface transformers' I-BERT implementation. Therefore, we first need to generate a checkpoint of ibert finetuned on the target downstream task. While you can do this on the original Huggingface repository, we also support our base branch ltp/base, where you can run the following code to finetune ibert on the GLUE tasks.

git checkout ltp/base
cd examples/text-classification
python run_glue.py --model_name_or_path kssteven/ibert-roberta-base --output_dir {CKPT} --task {TASK} --do_train --do_eval {--some_more_arguments}
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • Please refer to the Huggingface tutorial and the official documentation for more details on arguments and hyperparameters.
  • Note that by default ibert behaves the same as roberta (see this tutorial), hence the resulting model will be the same as roberta-base finetuned on the target GLUE task.
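
For example, to finetune on RTE (the output directory name below is just a placeholder):

python run_glue.py --model_name_or_path kssteven/ibert-roberta-base --output_dir checkpoints/ibert-rte --task RTE --do_train --do_eval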

The final model will be checkpointed in {CKPT}.

  • Remove {CKPT}/trainer_state.json.
  • In the configuration file {CKPT}/config.json, change (1) "architectures" to ["LTPForSequenceClassification"] and (2) "model_type" to "ltp".
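
After these edits, the relevant fields of {CKPT}/config.json should look as follows:

"architectures": ["LTPForSequenceClassification"],
"model_type": "ltp",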

Run Learned Token Pruning

Add the following lines to the configuration file {CKPT}/config.json.

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 
"scoring_mode": "mean",

final_token_threshold determines the token threshold of the last layer, and the thresholds of the remaining layers are scaled linearly from it. For instance, when final_token_threshold, i.e., the threshold for the last (12th) layer, is set to 0.01, the thresholds for the 3rd, 6th, and 9th layers will be 0.0025, 0.005, and 0.0075, respectively. This value is a hyperparameter, and we found that 0.01 works well in many cases.
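
As a minimal Python sketch of this schedule (assuming 1-indexed layers; an illustration, not the repository's code):

def layer_threshold(l, num_layers=12, final_token_threshold=0.01):
    # The threshold of layer l grows linearly with depth, reaching
    # final_token_threshold at the last layer.
    return final_token_threshold * l / num_layers

# Layers 3, 6, 9, and 12 -> 0.0025, 0.005, 0.0075, and 0.01
print([layer_threshold(l) for l in (3, 6, 9, 12)])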

The learnable mode consists of two stages: soft threshold and hard threshold. Please refer to our paper for more details.

1. Soft Threshold

We first train the model using the soft threshold mode. This trains the thresholds as well as the model parameters to search for the best threshold configuration.
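
As a rough sketch of the mechanism (our reading of the description above and of the paper; not the repository's exact code), the soft mode replaces the binary keep/drop decision with a differentiable sigmoid mask, which is what makes the thresholds trainable:

import torch

def soft_prune_mask(scores, threshold, temperature):
    # Tokens whose importance score falls below the learnable threshold
    # get a mask value near 0; tokens well above it get a value near 1.
    # The temperature controls how sharp the transition is.
    return torch.sigmoid((scores - threshold) / temperature)

# Hypothetical objective: the --lambda regularizer pushes mask values
# toward 0, so a higher lambda prunes more tokens.
# loss = task_loss + lambda_ * mask.mean()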

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr 2e-5 --temperature {T}\
  --lambda 0.1 --weight_decay 0 --bs 64 --masking_mode soft --epoch {epoch} --save_step 100 --no_load
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • You can assign a different learning rate via lr, but 2e-5 worked well.
  • We set {epoch} to 10 for smaller datasets (e.g., RTE, MRPC) and 1 for larger datasets (e.g., SST2, QNLI, MNLI).
  • The --no_load flag prevents loading the best model at the end of training (i.e., the final checkpoint will be the one from the end of training).
  • lambda is an important hyperparameter that controls the pruning level: the higher the value, the more tokens we prune. 0.01 ~ 0.2 worked well in many cases, but we recommend empirically searching for the best value.
  • temperature is another hyperparameter; 1e-3 ~ 1e-5 worked well. In the paper, we searched over {1e−4, 2e−4, 5e−4, 1e−3, 2e−3}.

The final model will be checkpointed in {CKPT_soft} = checkpoints/base/{TASK}/absolute_threshold/rate_{final_token_threshold}/temperature_{T}/lambda_{lambda}/lr_{lr}. Remove trainer_state.json from the checkpoint directory {CKPT_soft}.

2. Hard Threshold

Once the thresholds are learned, we fix their values, switch to the hard threshold mode, and finetune the model parameters only.
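
Under the same assumptions as the sketch above, the hard mode simply binarizes the decision (again a sketch, not the repository's code):

import torch

def hard_prune_mask(scores, threshold):
    # Non-differentiable 0/1 mask: tokens scoring below the frozen
    # threshold are dropped outright.
    return (scores > threshold).float()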

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT_soft} --lr {LR} --bs 64 --masking_mode hard --epoch 5 
  • We used {LR} ∈ {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, evaluation runs once per epoch.

The final model will be checkpointed in {CKPT_soft}/hard/lr_{LR}.

Run Baseline Methods

We additionally provide code to reproduce the baseline methods used in our paper (i.e., top-k and manual threshold).

Top-k Token Pruning

Add the following lines to {CKPT}/config.json.

"prune_mode": "topk",
"token_keep_rate": 0.2,

The token keep rate of the first three layers is 1, and that of the last layer is token_keep_rate; the keep rates of the remaining layers are scaled linearly in between. The smaller token_keep_rate is, the more aggressively we prune tokens. You can also assign a negative number to token_keep_rate, in which case the keep rate of each layer is clipped at max(0, keep_rate).
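
A minimal sketch of this schedule (assuming 1-indexed layers and linear interpolation from the 3rd to the 12th layer; an illustration, not the repository's code):

def layer_keep_rate(l, num_layers=12, token_keep_rate=0.2):
    # Layers 1-3 keep every token; deeper layers interpolate linearly
    # down to token_keep_rate at the last layer, clipped at 0 from below.
    if l <= 3:
        return 1.0
    rate = 1.0 + (token_keep_rate - 1.0) * (l - 3) / (num_layers - 3)
    return max(0.0, rate)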

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5
  • We used {LR} ∈ {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, evaluation runs once per epoch.

The final model will be checkpointed in {CKPT}/topk/lr_{LR}.

Manual (Non-learnable) Threshold Pruning

Add the following lines to {CKPT}/config.json.

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 
"scoring_mode": "mean",

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5 --save_step 500
  • We used {LR} ∈ {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, evaluation runs once per epoch.
  • Note that the only difference from the learned token pruning mode is that we run the hard threshold mode from the beginning.

The final model will be checkpointed in {CKPT}/hard/lr_{LR}.
