A Closer Look at Structured Pruning for Neural Network Compression

Overview

A Closer Look at Structured Pruning for Neural Network Compression

Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.

To prune, we fill our networks with custom MaskBlocks, which are manipulated using Pruner in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.

Setup

This is best done in a clean conda environment:

conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch

Repository layout

-train.py: contains all of the code for training large models from scratch and for training pruned models from scratch
-prune.py: contains the code for pruning trained models
-funcs.py: contains useful pruning functions and any functions we used commonly

CIFAR Experiments

First, you will need some initial models.

To train a WRN-40-2:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res'

The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:

python train.py --net='dense' --depth=100 --data_loc= --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1

These will automatically save checkpoints to the checkpoints folder.

Pruning

Once training is finished, we can prune our networks using prune.py (defaults are set to WRN pruning, so extra arguments are needed for DenseNets)

python prune.py --net='res'   --data_loc= --base_model='res' --save_file='res_fisher'
python prune.py --net='res'   --data_loc= --l1_prune=True --base_model='res' --save_file='res_l1'

python prune.py --net='dense' --depth 100 --data_loc= --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc= --l1_prune=True --base_model='dense' --save_file='dense_l1'  --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64  --no_epochs 2600

Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
Once finished, we can train the pruned models from scratch, e.g.:

python train.py --data_loc= --net='res' --base_file='res_fisher__prunes' --deploy --mask=1 --save_file='res_fisher__prunes_scratch'

Each model can then be evaluated using:

python train.py --deploy --eval --data_loc= --net='res' --mask=1 --base_file='res_fisher__prunes'

Training Reduced models

This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:

python train.py --net='res' --depth= --width= --data_loc= --save_file='res_reduced'

To add bottlenecks, use the following:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res_bottle' --bottle --bottle_mult 

With DenseNets you can modify the depth or growth, or use --bottle --bottle_mult as above.

Acknowledgements

Jack Turner wrote the L1 stuff, and some other stuff for that matter.

Code has been liberally borrowed from many a repo, including, but not limited to:

https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet

Citing this work

If you would like to cite this work, please use the following bibtex entry:

@article{crowley2018pruning,
  title={A Closer Look at Structured Pruning for Neural Network Compression},
  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
  journal={arXiv preprint arXiv:1810.04622},
  year={2018},
  }
Owner
Bayesian and Neural Systems Group
Machine learning research group @ University of Edinburgh
Bayesian and Neural Systems Group
[AAAI22] Reliable Propagation-Correction Modulation for Video Object Segmentation

Reliable Propagation-Correction Modulation for Video Object Segmentation (AAAI22) Preview version paper of this work is available at: https://arxiv.or

Xiaohao Xu 70 Dec 04, 2022
Deep-learning X-Ray Micro-CT image enhancement, pore-network modelling and continuum modelling

EDSR modelling A Github repository for deep-learning image enhancement, pore-network and continuum modelling from X-Ray Micro-CT images. The repositor

Samuel Jackson 7 Nov 03, 2022
Code for the ICML 2021 paper "Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Training and Effective Adaptation", Haoxiang Wang, Han Zhao, Bo Li.

Bridging Multi-Task Learning and Meta-Learning Code for the ICML 2021 paper "Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Trainin

AI Secure 57 Dec 15, 2022
Create images and texts with the First Order Generative Adversarial Networks

First Order Divergence for training GANs This repository contains code accompanying the paper First Order Generative Advesarial Netoworks The majority

Zalando Research 35 Dec 11, 2021
Towards Interpretable Deep Metric Learning with Structural Matching

DIML Created by Wenliang Zhao*, Yongming Rao*, Ziyi Wang, Jiwen Lu, Jie Zhou This repository contains PyTorch implementation for paper Towards Interpr

Wenliang Zhao 75 Nov 11, 2022
Template repository for managing machine learning research projects built with PyTorch-Lightning

Tutorial Repository with a minimal example for showing how to deploy training across various compute infrastructure.

Sidd Karamcheti 3 Feb 11, 2022
An Exact Solver for Semi-supervised Minimum Sum-of-Squares Clustering

PC-SOS-SDP: an Exact Solver for Semi-supervised Minimum Sum-of-Squares Clustering PC-SOS-SDP is an exact algorithm based on the branch-and-bound techn

Antonio M. Sudoso 1 Nov 13, 2022
The Face Mask recognition system uses AI technology to detect the person with or without a mask.

Face Mask Detection Face Mask Detection system built with OpenCV, Keras/TensorFlow using Deep Learning and Computer Vision concepts in order to detect

Rohan Kasabe 4 Apr 05, 2022
Pytorch implementation of MaskGIT: Masked Generative Image Transformer

Pytorch implementation of MaskGIT: Masked Generative Image Transformer

Dominic Rampas 247 Dec 16, 2022
Code accompanying "Adaptive Methods for Aggregated Domain Generalization"

Adaptive Methods for Aggregated Domain Generalization (AdaClust) Official Pytorch Implementation of Adaptive Methods for Aggregated Domain Generalizat

Xavier Thomas 15 Sep 20, 2022
MoCap-Solver: A Neural Solver for Optical Motion Capture Data

MoCap-Solver is a data-driven-based robust marker denoising method, which takes raw mocap markers as input and outputs corresponding clean markers and skeleton motions.

55 Dec 28, 2022
Cascaded Pyramid Network (CPN) based on Keras (Tensorflow backend)

ML2 Takehome Project Reimplementing the paper: Cascaded Pyramid Network for Multi-Person Pose Estimation Dataset The model uses the COCO dataset which

Vo Van Tu 1 Nov 22, 2021
Videocaptioning.pytorch - A simple implementation of video captioning

pytorch implementation of video captioning recommend installing pytorch and pyth

Yiyu Wang 2 Jan 01, 2022
Train Yolov4 using NBX-Jobs

yolov4-trainer-nbox Train Yolov4 using NBX-Jobs. Use the powerfull functionality available in nbox-SDK repo to train a tiny-Yolo v4 model on Pascal VO

Yash Bonde 1 Jan 12, 2022
A framework that allows people to write their own Rocket League bots.

YOU PROBABLY SHOULDN'T PULL THIS REPO Bot Makers Read This! If you just want to make a bot, you don't need to be here. Instead, start with one of thes

543 Dec 20, 2022
Learned model to estimate number of distinct values (NDV) of a population using a small sample.

Learned NDV estimator Learned model to estimate number of distinct values (NDV) of a population using a small sample. The model approximates the maxim

2 Nov 21, 2022
Code and data form the paper BERT Got a Date: Introducing Transformers to Temporal Tagging

BERT Got a Date: Introducing Transformers to Temporal Tagging Satya Almasian*, Dennis Aumiller*, and Michael Gertz Heidelberg University Contact us vi

54 Dec 04, 2022
Boosted neural network for tabular data

XBNet - Xtremely Boosted Network Boosted neural network for tabular data XBNet is an open source project which is built with PyTorch which tries to co

Tushar Sarkar 175 Jan 04, 2023
Step by Step on how to create an vision recognition model using LOBE.ai, export the model and run the model in an Azure Function

Step by Step on how to create an vision recognition model using LOBE.ai, export the model and run the model in an Azure Function

El Bruno 3 Mar 30, 2022
GPU implementation of $k$-Nearest Neighbors and Shared-Nearest Neighbors

GPU implementation of kNN and SNN GPU implementation of $k$-Nearest Neighbors and Shared-Nearest Neighbors Supported by numba cuda and faiss library E

Hyeon Jeon 7 Nov 23, 2022