A Closer Look at Structured Pruning for Neural Network Compression

Overview

A Closer Look at Structured Pruning for Neural Network Compression

Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.

To prune, we fill our networks with custom MaskBlocks, which are manipulated using Pruner in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.

Setup

This is best done in a clean conda environment:

conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch

Repository layout

-train.py: contains all of the code for training large models from scratch and for training pruned models from scratch
-prune.py: contains the code for pruning trained models
-funcs.py: contains useful pruning functions and any functions we used commonly

CIFAR Experiments

First, you will need some initial models.

To train a WRN-40-2:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res'

The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:

python train.py --net='dense' --depth=100 --data_loc= --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1

These will automatically save checkpoints to the checkpoints folder.

Pruning

Once training is finished, we can prune our networks using prune.py (defaults are set to WRN pruning, so extra arguments are needed for DenseNets)

python prune.py --net='res'   --data_loc= --base_model='res' --save_file='res_fisher'
python prune.py --net='res'   --data_loc= --l1_prune=True --base_model='res' --save_file='res_l1'

python prune.py --net='dense' --depth 100 --data_loc= --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc= --l1_prune=True --base_model='dense' --save_file='dense_l1'  --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64  --no_epochs 2600

Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
Once finished, we can train the pruned models from scratch, e.g.:

python train.py --data_loc= --net='res' --base_file='res_fisher__prunes' --deploy --mask=1 --save_file='res_fisher__prunes_scratch'

Each model can then be evaluated using:

python train.py --deploy --eval --data_loc= --net='res' --mask=1 --base_file='res_fisher__prunes'

Training Reduced models

This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:

python train.py --net='res' --depth= --width= --data_loc= --save_file='res_reduced'

To add bottlenecks, use the following:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res_bottle' --bottle --bottle_mult 

With DenseNets you can modify the depth or growth, or use --bottle --bottle_mult as above.

Acknowledgements

Jack Turner wrote the L1 stuff, and some other stuff for that matter.

Code has been liberally borrowed from many a repo, including, but not limited to:

https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet

Citing this work

If you would like to cite this work, please use the following bibtex entry:

@article{crowley2018pruning,
  title={A Closer Look at Structured Pruning for Neural Network Compression},
  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
  journal={arXiv preprint arXiv:1810.04622},
  year={2018},
  }
Owner
Bayesian and Neural Systems Group
Machine learning research group @ University of Edinburgh
Bayesian and Neural Systems Group
Indonesian Car License Plate Character Recognition using Tensorflow, Keras and OpenCV.

Monopol Indonesian Car License Plate (Indonesia Mobil Nomor Polisi) Character Recognition using Tensorflow, Keras and OpenCV. Background This applicat

Jayaku Briliantio 3 Apr 07, 2022
Autonomous Perception: 3D Object Detection with Complex-YOLO

Autonomous Perception: 3D Object Detection with Complex-YOLO LiDAR object detect

Thomas Dunlap 2 Feb 18, 2022
Human segmentation models, training/inference code, and trained weights, implemented in PyTorch

Human-Segmentation-PyTorch Human segmentation models, training/inference code, and trained weights, implemented in PyTorch. Supported networks UNet: b

Thuy Ng 474 Dec 19, 2022
LERP : Label-dependent and event-guided interpretable disease risk prediction using EHRs

LERP : Label-dependent and event-guided interpretable disease risk prediction using EHRs This is the code for the LERP. Dataset The dataset used is MI

5 Jun 18, 2022
The self-supervised goal reaching benchmark introduced in Discovering and Achieving Goals via World Models

Lexa-Benchmark Codebase for the self-supervised goal reaching benchmark introduced in 'Discovering and Achieving Goals via World Models'. Setup Create

1 Oct 14, 2021
Using BERT+Bi-LSTM+CRF

Chinese Medical Entity Recognition Based on BERT+Bi-LSTM+CRF Step 1 I share the dataset on my google drive, please download the whole 'CCKS_2019_Task1

Xiang WU 55 Dec 21, 2022
Unifying Global-Local Representations in Salient Object Detection with Transformer

GLSTR (Global-Local Saliency Transformer) This is the official implementation of paper "Unifying Global-Local Representations in Salient Object Detect

11 Aug 24, 2022
PyTorch implementation of federated learning framework based on the acceleration of global momentum

Federated Learning with Acceleration of Global Momentum PyTorch implementation of federated learning framework based on the acceleration of global mom

0 Dec 23, 2021
JittorVis - Visual understanding of deep learning models

JittorVis: Visual understanding of deep learning model JittorVis is an open-source library for understanding the inner workings of Jittor models by vi

thu-vis 182 Jan 06, 2023
Specificity-preserving RGB-D Saliency Detection

Specificity-preserving RGB-D Saliency Detection Authors: Tao Zhou, Huazhu Fu, Geng Chen, Yi Zhou, Deng-Ping Fan, and Ling Shao. 1. Preface This reposi

Tao Zhou 35 Jan 08, 2023
Sub-Cluster AdaCos: Learning Representations for Anomalous Sound Detection.

Accompanying code for the paper Sub-Cluster AdaCos: Learning Representations for Anomalous Sound Detection.

Kevin Wilkinghoff 6 Dec 01, 2022
FCOSR: A Simple Anchor-free Rotated Detector for Aerial Object Detection

FCOSR: A Simple Anchor-free Rotated Detector for Aerial Object Detection FCOSR: A Simple Anchor-free Rotated Detector for Aerial Object Detection arXi

59 Nov 29, 2022
Realtime segmentation with ENet, the fast and accurate segmentation net.

Enet This is a realtime segmentation net with almost 22 fps on GTX1080 ti, and the model size is very small with only 28M. This repo contains the infe

JinTian 14 Aug 30, 2022
StyleGAN2 - Official TensorFlow Implementation

StyleGAN2 - Official TensorFlow Implementation

NVIDIA Research Projects 10.1k Dec 28, 2022
Plugin adapted from Ultralytics to bring YOLOv5 into Napari

napari-yolov5 Plugin adapted from Ultralytics to bring YOLOv5 into Napari. Training and detection can be done using the GUI. Training dataset must be

2 May 05, 2022
Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) - PyTorch Implementation

Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding PyTorch implementation for the Scalable Attentive Sentence-Pair Modeling vi

Microsoft 25 Dec 02, 2022
classify fashion-mnist dataset with pytorch

Fashion-Mnist Classifier with PyTorch Inference 1- clone this repository: git clone https://github.com/Jhamed7/Fashion-Mnist-Classifier.git 2- Instal

1 Jan 14, 2022
TinyML Cookbook, published by Packt

TinyML Cookbook This is the code repository for TinyML Cookbook, published by Packt. Author: Gian Marco Iodice Publisher: Packt About the book This bo

Packt 93 Dec 29, 2022
A Fast Sequence Transducer Implementation with PyTorch Bindings

transducer A Fast Sequence Transducer Implementation with PyTorch Bindings. The corresponding publication is Sequence Transduction with Recurrent Neur

Awni Hannun 184 Dec 18, 2022