Learned Token Pruning for Transformers

Overview

LTP: Learned Token Pruning for Transformers

Check our paper for more details.

Installation

We follow the same installation procedure as the original Huggingface Transformers repository.

pip install scikit-learn scipy datasets torch
pip install -e .  # in the top directory

Prepare Checkpoints

LTP is implemented on top of the I-BERT implementation in Huggingface Transformers. Therefore, we first need to generate a checkpoint of I-BERT fine-tuned on the target downstream task. While you can do this with the original Huggingface repository, we also provide a base branch, ltp/base, where you can run the following commands to fine-tune I-BERT on the GLUE tasks.

git checkout ltp/base
cd examples/text-classification
python run_glue.py --model_name_or_path kssteven/ibert-roberta-base --output_dir {CKPT} --task {TASK} --do_train --do_eval {--some_more_arguments}
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • Please refer to the Huggingface tutorial and the official documentation for more details on arguments and hyperparameters.
  • Note that, by default, I-BERT behaves the same as RoBERTa (see this tutorial), so the resulting model is the same as roberta-base fine-tuned on the target GLUE task.

The final model will be checkpointed in {CKPT}. Then, prepare the checkpoint for LTP:

  • Remove {CKPT}/trainer_state.json.
  • In the configuration file {CKPT}/config.json, change (1) "architectures" to ["LTPForSequenceClassification"] and (2) "model_type" to "ltp".
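The two checkpoint edits above can be scripted. The following is a minimal sketch using only the Python standard library; `prepare_ltp_checkpoint` is a hypothetical helper, not part of the repository, and its optional `updates` argument is our own addition for merging extra keys into config.json.

```python
import json
import os
from typing import Optional


def prepare_ltp_checkpoint(ckpt_dir: str, updates: Optional[dict] = None) -> dict:
    """Hypothetical helper: prepare a fine-tuned I-BERT checkpoint for LTP.

    Deletes trainer_state.json (if present), points config.json at the LTP
    architecture and model type, and optionally merges extra config keys.
    Returns the updated configuration dict.
    """
    state_path = os.path.join(ckpt_dir, "trainer_state.json")
    if os.path.exists(state_path):
        os.remove(state_path)

    config_path = os.path.join(ckpt_dir, "config.json")
    with open(config_path) as f:
        config = json.load(f)

    # The two edits listed above.
    config["architectures"] = ["LTPForSequenceClassification"]
    config["model_type"] = "ltp"
    if updates:
        config.update(updates)  # e.g. pruning settings such as "prune_mode"

    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    return config
```

For example, calling `prepare_ltp_checkpoint(CKPT)` performs both edits; passing the pruning keys from the next section as `updates` would write them in the same step.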

Run Learned Token Pruning

Add the following lines to the configuration file {CKPT}/config.json:

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 
"scoring_mode": "mean",

final_token_threshold determines the token threshold of the last layer; the thresholds of the remaining layers are scaled linearly. For instance, setting final_token_threshold (i.e., the threshold for the last, 12th layer) to 0.01 gives thresholds of 0.0025, 0.005, and 0.0075 for the 3rd, 6th, and 9th layers, respectively. This value is a hyperparameter, and we found that 0.01 works well in many cases.
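The linear scaling can be checked with a few lines of Python. This sketch assumes that layer l (1-indexed) of a 12-layer model gets final_token_threshold * l / 12, which reproduces the example values above; `layer_thresholds` is a hypothetical name, and the repository may compute the schedule differently.

```python
from typing import List


def layer_thresholds(final_token_threshold: float, num_layers: int = 12) -> List[float]:
    """Hypothetical helper: threshold for layer l is final_token_threshold * l / num_layers."""
    return [final_token_threshold * l / num_layers for l in range(1, num_layers + 1)]


thresholds = layer_thresholds(0.01)
for layer in (3, 6, 9, 12):
    # Lists are 0-indexed, layers are 1-indexed.
    print(f"layer {layer}: {thresholds[layer - 1]:.4f}")
```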

The learnable mode consists of 2 stages: soft threshold and hard threshold. Please refer to our paper for more details.

1. Soft Threshold

We first train the model using the soft threshold mode. This trains the thresholds as well as the model parameters to search for the best threshold configuration.

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr 2e-5 --temperature {T} \
  --lambda 0.1 --weight_decay 0 --bs 64 --masking_mode soft --epoch {epoch} --save_step 100 --no_load
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • You can use a different learning rate for --lr, but 2e-5 worked well.
  • We set {epoch} to 10 for smaller datasets (e.g., RTE, MRPC) and 1 for larger datasets (e.g., SST2, QNLI, MNLI).
  • The --no_load flag skips loading the best model at the end of training (i.e., the final checkpoint is the one from the end of training).
  • lambda is an important hyperparameter that controls the pruning level: the higher the value, the more tokens are pruned. Values of 0.01 ~ 0.2 worked well in many cases, but we recommend searching empirically for the best value.
  • temperature is another hyperparameter, and 1e-3 ~ 1e-5 worked well. In the paper, we searched over {1e-4, 2e-4, 5e-4, 1e-3, 2e-3}.
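To make the roles of temperature and lambda more concrete, here is a minimal NumPy sketch of our reading of the two masking modes; it is not the repository's code. It assumes that the "mean" score of a token is the attention probability it receives, averaged over heads and query positions; that the soft mask is sigmoid((score - threshold) / temperature); and that lambda multiplies a regularizer equal to the average over layers of each layer's mask sum. All function names are hypothetical.

```python
import numpy as np


def token_scores(attn: np.ndarray) -> np.ndarray:
    """'mean' scoring (assumed): average attention probability each token receives.

    attn: softmax attention probabilities with shape (heads, queries, keys).
    Returns one score per key token, shape (keys,).
    """
    return attn.mean(axis=(0, 1))


def soft_mask(scores: np.ndarray, threshold: float, temperature: float) -> np.ndarray:
    """Soft-threshold stage: sigmoid((score - threshold) / temperature).

    Written as 0.5 * (1 + tanh(x / 2)), which equals sigmoid(x) but cannot
    overflow when the temperature is very small.
    """
    x = (scores - threshold) / temperature
    return 0.5 * (1.0 + np.tanh(x / 2.0))


def hard_mask(scores: np.ndarray, threshold: float) -> np.ndarray:
    """Hard-threshold stage: keep (1) tokens scoring above the threshold, drop (0) the rest."""
    return (scores > threshold).astype(float)


def mask_regularizer(layer_masks) -> float:
    """Average over layers of each layer's L1 mask norm; lambda scales this term in the loss."""
    return float(np.mean([np.abs(m).sum() for m in layer_masks]))


# Toy example: random attention for one layer with 12 heads and 8 tokens.
rng = np.random.default_rng(0)
logits = rng.normal(size=(12, 8, 8))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
scores = token_scores(attn)  # sums to 1 across tokens, so the mean score is 1/8
print(soft_mask(scores, threshold=0.1, temperature=1e-3))
print(hard_mask(scores, threshold=0.1))
```

Under these assumptions, a smaller temperature makes the sigmoid steeper, so the soft mask approaches the 0/1 mask of the hard-threshold stage, while a larger lambda penalizes the mask sums more and therefore pushes more tokens below their thresholds.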

The final model will be checkpointed in {CKPT_soft} = checkpoints/base/{TASK}/absolute_threshold/rate_{final_token_threshold}/temperature_{T}/lambda_{lambda}/lr_{lr}. Remove trainer_state.json from the checkpoint directory {CKPT_soft}.

2. Hard Threshold

Once the thresholds are learned, we fix them, switch back to the hard threshold mode, and fine-tune only the model parameters.

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT_soft} --lr {LR} --bs 64 --masking_mode hard --epoch 5 
  • In the paper, we used {LR} in {0.5, 1, 2}e-5.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, the model is evaluated once per epoch.

The final model will be checkpointed in {CKPT_soft}/hard/lr_{LR}.

Run Baseline Methods

We additionally provide code to reproduce the baseline methods used in our paper (i.e., top-k and manual threshold).

Top-k Token Pruning

Add the following lines to {CKPT}/config.json:

"prune_mode": "topk",
"token_keep_rate": 0.2,

The token keep rates of the first three layers and of the last layer are 1 and token_keep_rate, respectively. The keep rates of the remaining layers are scaled linearly. The smaller token_keep_rate is, the more aggressively tokens are pruned. You can also set token_keep_rate to a negative number; in that case, the keep rate of each layer is set to max(0, keep_rate).
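The keep-rate schedule can be sketched as follows. This is one reading of the description above, not the repository's code: it assumes the linear interpolation runs from 1.0 at layer 3 to token_keep_rate at layer 12, and that a layer keeps round(keep_rate * n) of its n tokens, ranked by an importance score. Both `topk_keep_rates` and `topk_mask` are hypothetical names.

```python
from typing import List

import numpy as np


def topk_keep_rates(token_keep_rate: float, num_layers: int = 12) -> List[float]:
    """Keep rate per layer: 1.0 for layers 1-3, token_keep_rate at the last layer,
    linearly interpolated in between (assumed endpoints), clamped at 0."""
    rates = []
    for layer in range(1, num_layers + 1):
        if layer <= 3:
            rate = 1.0
        else:
            frac = (layer - 3) / (num_layers - 3)  # 0 at layer 3, 1 at the last layer
            rate = (1.0 - frac) + frac * token_keep_rate
        rates.append(max(0.0, rate))  # negative rates are clamped via max(0, keep_rate)
    return rates


def topk_mask(scores: np.ndarray, keep_rate: float) -> np.ndarray:
    """Binary mask keeping the round(keep_rate * n) highest-scoring tokens.

    scores: 1-D array of per-token importance scores.
    """
    n = scores.shape[0]
    k = min(n, round(keep_rate * n))  # the rounding rule is an assumption
    mask = np.zeros(n)
    mask[np.argsort(-scores)[:k]] = 1.0  # indices of the k largest scores
    return mask


print([round(r, 3) for r in topk_keep_rates(0.2)])
print([round(r, 3) for r in topk_keep_rates(-0.5)])  # later layers clamp to 0
```

A negative token_keep_rate steepens the line, so it crosses zero before the last layer and every layer from that point on is clamped to a keep rate of 0.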

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5
  • In the paper, we used {LR} in {0.5, 1, 2}e-5.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, the model is evaluated once per epoch.

The final model will be checkpointed in {CKPT}/topk/lr_{LR}.

Manual (Non-learnable) Threshold Pruning

Add the following lines to {CKPT}/config.json:

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 
"scoring_mode": "mean",

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5 --save_step 500
  • In the paper, we used {LR} in {0.5, 1, 2}e-5.
  • The command above sets --save_step 500 for more frequent evaluation/logging; without it, the model is evaluated once per epoch (the default).
  • Note that the only difference from learned token pruning is that we run in the hard threshold mode from the beginning.

The final model will be checkpointed in {CKPT}/hard/lr_{LR}.
