Block Sparse movement pruning

Overview

Movement Pruning: Adaptive Sparsity by Fine-Tuning

Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; however, it is less effective in the transfer learning regime that has become standard for state-of-the-art natural language processing applications. We propose the use of movement pruning, a simple, deterministic first-order weight pruning method that is more adaptive to pretrained model fine-tuning. Experiments show that when pruning large pretrained language models, movement pruning yields significant improvements in high-sparsity regimes. When combined with distillation, the approach achieves minimal accuracy loss with as little as 3% of the model parameters:

Fine-pruning + Distillation (Teacher = BERT-base fine-tuned), Dev set results

Task (metric)          BERT-base fine-tuned  Remaining weights  Magnitude Pruning  L0 Regularization  Movement Pruning  Soft Movement Pruning
SQuAD Dev (EM/F1)      80.4/88.1             10%                70.2/80.1          72.4/81.9          75.6/84.3         76.6/84.9
                                             3%                 45.5/59.6          64.3/75.8          67.5/78.0         72.7/82.3
MNLI Dev (acc/MM acc)  84.5/84.9             10%                78.3/79.3          78.7/79.7          80.1/80.4         81.2/81.8
                                             3%                 69.4/70.6          76.0/76.2          76.5/77.4         79.5/80.1
QQP Dev (acc/F1)       91.4/88.4             10%                79.8/65.0          88.1/82.8          89.7/86.2         90.2/86.8
                                             3%                 72.4/57.8          87.0/81.9          86.1/81.5          89.1/85.5

This page explains how to fine-prune pre-trained models such as BERT into extremely sparse models with movement pruning. In contrast to magnitude pruning, which keeps the weights that are far from 0, movement pruning retains the weights that are moving away from 0 during fine-tuning.
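To make the contrast concrete, here is a toy NumPy sketch (not the repository's implementation). It assumes the first-order movement score described in the paper, which accumulates -(dL/dW) * W over training, and uses made-up weights and a constant, hypothetical gradient. `topk_mask` is a hypothetical helper, and for simplicity the mask is not applied during the toy updates.

```python
import numpy as np

def topk_mask(scores: np.ndarray, keep_ratio: float) -> np.ndarray:
    """Boolean mask keeping the `keep_ratio` fraction of entries with the highest scores."""
    k = max(1, int(round(keep_ratio * scores.size)))
    flat = scores.ravel()
    keep_idx = np.argpartition(flat, -k)[-k:]  # indices of the k largest scores
    mask = np.zeros(flat.shape, dtype=bool)
    mask[keep_idx] = True
    return mask.reshape(scores.shape)

# Toy weights and a constant gradient signal (hypothetical values).
weights = np.array([2.0, -1.5, 0.1, -0.2])
grads = np.array([0.5, -0.4, -0.3, 0.2])  # dL/dW at every step
lr = 0.1
movement_scores = np.zeros_like(weights)

for _ in range(3):
    # Movement score accumulates -dL/dW * W: it grows when a weight moves away from 0.
    movement_scores += -lr * grads * weights
    # Plain SGD step on the weights (mask ignored in this toy).
    weights = weights - lr * grads

magnitude_mask = topk_mask(np.abs(weights), keep_ratio=0.5)
movement_mask = topk_mask(movement_scores, keep_ratio=0.5)
print("weights:        ", np.round(weights, 2))
print("magnitude keeps:", magnitude_mask)
print("movement keeps: ", movement_mask)
```

In this toy run, magnitude pruning keeps the two large weights even though both are shrinking toward 0, whereas the movement scores favour the two small weights that are growing away from 0.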

For more information, we invite you to check out our paper. You can also have a look at this fun Explain Like I'm Five introductory slide deck.

Extreme sparsity and efficient storage

One promise of extreme pruning is to obtain very small models that can easily be sent to (and stored on) edge devices. By setting weights to 0, we reduce the amount of information we need to store, and thus decrease the memory footprint. With movement pruning, we obtain extremely sparse fine-pruned models: ~95% of the dense performance with ~5% of the weights remaining in the BERT encoder.

In this notebook, we showcase how standard, out-of-the-box tools can be used to efficiently store an extremely sparse question answering model (only 6% of the weights remaining in the encoder). We reduce the memory size of the encoder from 340MB (the original dense BERT) to 11MB, without any additional training of the model (every operation is performed post fine-pruning). That is small enough to fit on a 91' floppy disk 📎!
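The notebook itself is not reproduced here. As a rough, hypothetical illustration of why sparse formats help, the NumPy sketch below compares the bytes of a dense float32 matrix with a hand-built CSR (compressed sparse row) copy, optionally with symmetric 8-bit quantization of the stored values. The 768 x 3072 shape (one BERT-base feed-forward matrix), the random weights and the 6% density are illustrative assumptions, not the notebook's actual pipeline.

```python
import numpy as np

def quantize_int8(values: np.ndarray):
    """Symmetric linear 8-bit quantization (assumed scheme); returns (int8 values, scale)."""
    scale = float(np.abs(values).max()) / 127.0 if values.size else 1.0
    if scale == 0.0:
        scale = 1.0
    return np.round(values / scale).astype(np.int8), scale

def csr_bytes(w: np.ndarray, quantize: bool = False) -> int:
    """Bytes for a CSR copy of `w`: stored values, int32 column indices, int32 row pointers."""
    rows, cols = np.nonzero(w)  # row-major order, which is what CSR needs
    values = w[rows, cols].astype(np.float32)
    if quantize:
        values, _scale = quantize_int8(values)
    col_idx = cols.astype(np.int32)
    row_ptr = np.zeros(w.shape[0] + 1, dtype=np.int32)
    row_ptr[1:] = np.cumsum(np.bincount(rows, minlength=w.shape[0]))
    return values.nbytes + col_idx.nbytes + row_ptr.nbytes

rng = np.random.default_rng(0)
w = rng.standard_normal((768, 3072)).astype(np.float32)
# Keep roughly the 6% largest-magnitude entries (a stand-in for a fine-pruned matrix).
cutoff = np.quantile(np.abs(w), 0.94)
w_sparse = np.where(np.abs(w) >= cutoff, w, 0.0).astype(np.float32)

dense = w_sparse.nbytes
print(f"dense float32: {dense / 1e6:.2f} MB")
print(f"CSR float32:   {csr_bytes(w_sparse) / 1e6:.2f} MB")
print(f"CSR int8:      {csr_bytes(w_sparse, quantize=True) / 1e6:.2f} MB")
```

At ~6% density the column indices dominate the CSR cost, which is one reason quantizing only the values gives diminishing returns without further tricks.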

While movement pruning does not directly optimize for memory footprint (but rather for the number of non-null weights), we hypothesize that further memory compression can be achieved with specific quantization-aware training (see for instance Q8BERT, And the Bit Goes Down or Quant-Noise).

Fine-pruned models

As examples, we release two English PruneBERT checkpoints (models fine-pruned from a pre-trained BERT checkpoint), one on SQuAD and the other on MNLI.

  • prunebert-base-uncased-6-finepruned-w-distil-squad
    Pre-trained BERT-base-uncased fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from BERT-base-uncased fine-tuned on SQuAD. Only 6% of the encoder's weights are non-null, and the model reaches an 83.8 F1 score. The model can be accessed with: pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")
  • prunebert-base-uncased-6-finepruned-w-distil-mnli
    Pre-trained BERT-base-uncased fine-pruned with soft movement pruning on MNLI. We use an additional distillation signal from BERT-base-uncased fine-tuned on MNLI. Only 6% of the encoder's weights are non-null, and the model reaches 80.7 (matched) accuracy. The model can be accessed with: pruned_bert = BertForSequenceClassification.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")
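One way to sanity-check the "6% non-null" figure is to count non-zero entries in the encoder's weight matrices. The sketch below is a hypothetical helper (`encoder_density` is not part of the repository) that works on any state-dict-like mapping; it assumes Hugging Face BERT parameter names (e.g. `bert.encoder.layer.0.attention.self.query.weight`) and that pruned weights are stored as explicit zeros.

```python
import numpy as np

def encoder_density(state_dict) -> float:
    """Fraction of non-zero entries across the encoder's linear weight matrices.

    Assumes Hugging Face BERT parameter names such as
    'bert.encoder.layer.0.attention.self.query.weight'; biases, LayerNorm
    parameters and embeddings are skipped.
    """
    total, nonzero = 0, 0
    for name, param in state_dict.items():
        if name.startswith("bert.encoder.") and name.endswith(".weight") and "LayerNorm" not in name:
            arr = np.asarray(param)
            total += arr.size
            nonzero += int(np.count_nonzero(arr))
    return nonzero / total if total else 0.0

# Toy stand-in for a state dict (hypothetical values).
toy = {
    "bert.embeddings.word_embeddings.weight": np.ones((3, 2)),
    "bert.encoder.layer.0.attention.self.query.weight": np.array([[0.0, 1.0], [0.0, 0.0]]),
    "bert.encoder.layer.0.attention.self.query.bias": np.ones(2),
    "bert.encoder.layer.0.intermediate.dense.weight": np.array([[1.0, 0.0], [0.0, 0.0]]),
    "bert.encoder.layer.0.output.LayerNorm.weight": np.ones(2),
}
print(encoder_density(toy))
```

With a loaded checkpoint, one could pass `{k: v.cpu().numpy() for k, v in pruned_bert.state_dict().items()}`; if the pruned weights are indeed stored as zeros (our assumption), the result should land close to 0.06.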

How to fine-prune?

Setup

The code relies on the 🤗 Transformers library. In addition to the dependencies listed in the examples folder, you should install a few additional dependencies listed in the requirements.txt file: pip install -r requirements.txt.

Note that we built our experiments on top of a stabilized version of the library (commit https://github.com/huggingface/transformers/commit/352d5472b0c1dec0f420d606d16747d851b4bda8): we do not guarantee that everything is still compatible with the latest version of the master branch.

Fine-pruning with movement pruning

Below, we detail how to reproduce the results reported in the paper. We use SQuAD as a running example. Commands (and scripts) can be easily adapted for other tasks.

The following command fine-prunes a pre-trained BERT-base on SQuAD using movement pruning, down to 15% of remaining weights (85% sparsity). Note that all the embedding modules are kept frozen at their pre-trained values, and only the fully connected layers of the encoder (the 12 Transformer blocks) are pruned.

SERIALIZATION_DIR=<OUTPUT_DIR>
SQUAD_DATA=squad_data

mkdir $SQUAD_DATA
cd $SQUAD_DATA
wget -q https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
wget -q https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
cd ..


python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
    --initial_threshold 1 --final_threshold 0.15 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method topK --mask_init constant --mask_scale 0.
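With topK, the threshold is the fraction of weights kept (1 = dense, 0.15 = 15% remaining). The sketch below shows how such a threshold could be annealed over training. It assumes a cubic sparsity schedule as described in the paper, and our own reading of the flags: --initial_warmup and --final_warmup are treated as multiples of --warmup_steps during which the threshold is held constant. `scheduled_threshold` and the total step count are hypothetical.

```python
def scheduled_threshold(step, total_steps, warmup_steps,
                        initial_threshold, final_threshold,
                        initial_warmup, final_warmup):
    """Cubic decay of the pruning threshold (assumed schedule, not the script's code).

    The threshold stays at `initial_threshold` for the first
    `initial_warmup * warmup_steps` steps, stays at `final_threshold` for the
    last `final_warmup * warmup_steps` steps, and decays cubically in between.
    """
    start = initial_warmup * warmup_steps
    end = total_steps - final_warmup * warmup_steps
    if step <= start:
        return initial_threshold
    if step >= end:
        return final_threshold
    progress = (step - start) / (end - start)
    return final_threshold + (initial_threshold - final_threshold) * (1 - progress) ** 3

# Flag values from the topK command above; TOTAL_STEPS is a hypothetical figure.
TOTAL_STEPS = 55_000
for step in (0, 5_400, 15_000, 24_800, 35_000, 44_200, 55_000):
    t = scheduled_threshold(step, TOTAL_STEPS, warmup_steps=5_400,
                            initial_threshold=1.0, final_threshold=0.15,
                            initial_warmup=1, final_warmup=2)
    print(f"step {step:>6}: keep {t:.1%} of weights")
```

A cubic schedule prunes aggressively early, while many weights are still redundant, and slows down near the target so the remaining weights have time to adapt.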

Fine-pruning with other methods

We can also explore other fine-pruning methods by changing the pruning_method parameter (together with the related threshold, mask and regularization flags in each command):

Soft movement pruning

python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
    --initial_threshold 0 --final_threshold 0.1 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method sigmoied_threshold --mask_init constant --mask_scale 0. \
    --regularization l1 --final_lambda 400.
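Our reading of the flags above, stated as an assumption: with sigmoied_threshold, a weight is kept when the sigmoid of its score exceeds the scheduled threshold (0 keeps everything, 0.1 prunes progressively), and --regularization l1 adds a penalty on the squashed scores weighted by --final_lambda, so sparsity comes from the penalty rather than a fixed budget. The helper names and the averaging used for the penalty below are hypothetical and may differ from the script.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def soft_movement_mask(scores, threshold):
    """Keep weights whose sigmoid-squashed movement score exceeds the threshold."""
    return sigmoid(scores) > threshold

def l1_score_penalty(score_matrices, lam):
    """lam * mean of sigmoid(scores) (assumed normalization).

    Pushing scores down lowers sigmoid(scores), so more weights fall under the threshold.
    """
    return lam * float(np.mean([sigmoid(s).mean() for s in score_matrices]))

scores = np.array([-5.0, 0.0, 3.0])
print(soft_movement_mask(scores, threshold=0.0))  # a threshold of 0 keeps everything
print(soft_movement_mask(scores, threshold=0.1))
print(l1_score_penalty([scores], lam=400.0))
```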

L0 regularization

python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-1 \
    --initial_threshold 1. --final_threshold 1. \
    --initial_warmup 1 --final_warmup 1 \
    --pruning_method l0 --mask_init constant --mask_scale 2.197 \
    --regularization l0 --final_lambda 125.
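L0 regularization is commonly implemented with the hard-concrete gates of Louizos et al. (2018). The sketch below follows that formulation; the constants (stretch interval -0.1 to 1.1, temperature 2/3) are that paper's defaults and are assumptions here, as is our reading that --mask_init constant --mask_scale 2.197 initializes every score (log-alpha) to 2.197. Note that 2.197 is approximately ln 9, so sigmoid(2.197) is approximately 0.9.

```python
import numpy as np

# Hard-concrete constants from Louizos et al. (2018); assumed, not read from the script.
LEFT, RIGHT, BETA = -0.1, 1.1, 2.0 / 3.0

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_gate(log_alpha, rng):
    """Training-time stochastic gate in [0, 1] (hard-concrete sample)."""
    u = rng.uniform(1e-6, 1.0 - 1e-6, size=np.shape(log_alpha))
    s = sigmoid((np.log(u) - np.log(1.0 - u) + log_alpha) / BETA)
    return np.clip(s * (RIGHT - LEFT) + LEFT, 0.0, 1.0)

def deterministic_gate(log_alpha):
    """Test-time gate: stretched sigmoid clipped to [0, 1], so exact zeros are possible."""
    return np.clip(sigmoid(log_alpha) * (RIGHT - LEFT) + LEFT, 0.0, 1.0)

def expected_l0(log_alpha):
    """Probability that each gate is non-zero; summing it gives the L0 penalty."""
    return sigmoid(log_alpha - BETA * np.log(-LEFT / RIGHT))

init = 2.197  # assumed initial score from --mask_scale
print(round(float(sigmoid(init)), 3))
print(deterministic_gate(np.array([-5.0, init, 10.0])))
print(sample_gate(np.full(5, init), np.random.default_rng(0)))
```

Because the stretched interval extends past 0 and 1, clipping lets gates become exactly 0 (pruned) or exactly 1, which is what makes the expected L0 norm differentiable yet still produces hard sparsity.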

Iterative Magnitude Pruning

python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 \
    --initial_threshold 1 --final_threshold 0.15 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method magnitude

After fine-pruning

Counting parameters

Regularization-based pruning methods (soft movement pruning and L0 regularization) rely on a penalty to induce sparsity, and the penalty's multiplicative coefficient (final_lambda in the commands above) controls the sparsity level. To obtain the effective sparsity level in the encoder, we simply count the number of activated (non-null) weights:

python examples/movement-pruning/counts_parameters.py \
    --pruning_method sigmoied_threshold \
    --threshold 0.1 \
    --serialization_dir $SERIALIZATION_DIR
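The counting itself is simple. The sketch below illustrates it for the sigmoied_threshold case on a hypothetical dict of score matrices (the key names are made up); the actual script reads the fine-pruned checkpoint from --serialization_dir, which is not reproduced here.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def remaining_weights(mask_scores, threshold):
    """Return (kept, total) over all pruned matrices for sigmoid-threshold masks."""
    kept = sum(int((sigmoid(s) > threshold).sum()) for s in mask_scores.values())
    total = sum(s.size for s in mask_scores.values())
    return kept, total

# Hypothetical learned scores for two pruned matrices.
scores = {
    "layer.0.query": np.array([[3.0, -4.0], [0.5, -0.1]]),
    "layer.0.key": np.array([-6.0, 2.0]),
}
kept, total = remaining_weights(scores, threshold=0.1)
print(f"{kept}/{total} weights remaining ({kept / total:.1%})")
```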

Pruning once and for all

Once the model has been fine-pruned, the pruned weights can be set to 0 once and for all (reducing the amount of information to store). In our running example, we can convert a MaskedBertForQuestionAnswering (a BERT model augmented to enable on-the-fly pruning) into a standard BertForQuestionAnswering:

python examples/movement-pruning/bertarize.py \
    --pruning_method sigmoied_threshold \
    --threshold 0.1 \
    --model_name_or_path $SERIALIZATION_DIR
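Conceptually, this conversion multiplies each weight by its binary mask and then drops the scores. The sketch below shows that core step for the sigmoied_threshold case on a NumPy state dict; `bake_masks` is hypothetical, the convention that the scores of `X.weight` live under `X.mask_scores` is an assumption, and the real script (which also accepts other pruning methods via --pruning_method) may do more.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bake_masks(state_dict, threshold, score_suffix="mask_scores"):
    """Multiply each weight by its binary mask and drop the score tensors.

    Assumes (hypothetically) that the scores for 'X.weight' are stored under 'X.mask_scores'.
    """
    baked = {}
    for name, value in state_dict.items():
        if name.endswith("." + score_suffix):
            continue  # scores are not needed once the mask is baked into the weights
        if name.endswith(".weight"):
            scores_name = name[: -len("weight")] + score_suffix
            if scores_name in state_dict:
                mask = sigmoid(state_dict[scores_name]) > threshold
                value = value * mask  # pruned entries become exact zeros
        baked[name] = value
    return baked

masked = {
    "encoder.query.weight": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "encoder.query.mask_scores": np.array([[5.0, -5.0], [-5.0, 5.0]]),
    "encoder.query.bias": np.array([0.1, 0.2]),
}
plain = bake_masks(masked, threshold=0.1)
print(sorted(plain))
print(plain["encoder.query.weight"])
```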

Hyper-parameters

For reproducibility purposes, we share the detailed results presented in the paper. These tables exhaustively describe the individual hyper-parameters used for each data point.

Inference speed

Early experiments show that even though models fine-pruned with (soft) movement pruning are extremely sparse, they do not run significantly faster with standard PyTorch inference. We are currently benchmarking and exploring inference setups designed specifically for sparse architectures. In particular, hardware manufacturers are announcing devices that should considerably speed up inference for sparse networks.

Citation

If you find this resource useful, please consider citing the following paper:

@article{sanh2020movement,
    title={Movement Pruning: Adaptive Sparsity by Fine-Tuning},
    author={Victor Sanh and Thomas Wolf and Alexander M. Rush},
    year={2020},
    eprint={2005.07683},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}