Neuron Merging: Compensating for Pruned Neurons (NeurIPS 2020)

Overview


PyTorch implementation of Neuron Merging: Compensating for Pruned Neurons, accepted at the 34th Conference on Neural Information Processing Systems (NeurIPS 2020).

Requirements

To install requirements:

conda env create -f ./environment.yml

Python environment & main libraries:

  • python 3.8
  • pytorch 1.5.0
  • scikit-learn 0.22.1
  • torchvision 0.6.0

LeNet-300-100

To test the LeNet-300-100 model on FashionMNIST, run:

bash scripts/LeNet_300_100_FashionMNIST.sh -t [model type] -c [criterion] -r [pruning ratio]

You can use three arguments for this script:

  • model type: original | prune | merge
  • pruning criterion: l1-norm | l2-norm | l2-GM
  • pruning ratio: 0.0 ~ 1.0
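The README does not define the three criteria. As a rough illustration, the sketch below scores each neuron by its incoming weight vector (one row per neuron). The $l_1$-norm and $l_2$-norm criteria rank neurons by magnitude. For l2-GM we assume the geometric-median idea used by FPGM, approximated here by each neuron's summed Euclidean distance to every other neuron in the layer: neurons closest to the "center" are treated as the most redundant. Flooring `ratio * n` to get the number of pruned neurons, and all function names, are our assumptions, not the repository's code.

```python
import math


def l1_scores(weights):
    """l1-norm of each neuron's incoming weight vector (one row per neuron)."""
    return [sum(abs(w) for w in row) for row in weights]


def l2_scores(weights):
    """l2-norm of each neuron's incoming weight vector."""
    return [math.sqrt(sum(w * w for w in row)) for row in weights]


def gm_scores(weights):
    """Summed Euclidean distance from each neuron to all neurons in the layer.

    A small sum means the neuron lies near the layer's geometric median and is
    assumed to be redundant (FPGM-style approximation).
    """
    return [sum(math.dist(a, b) for b in weights) for a in weights]


CRITERIA = {"l1-norm": l1_scores, "l2-norm": l2_scores, "l2-GM": gm_scores}


def select_pruned(weights, criterion, ratio):
    """Return the (sorted) indices of the neurons pruned at the given ratio."""
    scores = CRITERIA[criterion](weights)
    n_prune = int(len(weights) * ratio)  # assumed rounding: floor
    order = sorted(range(len(weights)), key=lambda i: scores[i])
    return sorted(order[:n_prune])
```

For example, with four two-input neurons `[[0.1, -0.2], [1.0, 1.0], [-0.5, 0.3], [2.0, 0.0]]` and a ratio of 0.5, both norm criteria prune neurons 0 and 2 (the smallest vectors), while the geometric-median criterion prunes neurons 0 and 1 (the ones closest to the others).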

For example, to test the model after pruning 50% of the neurons with the $l_1$-norm criterion, run:

bash scripts/LeNet_300_100_FashionMNIST.sh -t prune -c l1-norm -r 0.5

To test the same configuration with merging instead of pruning, run:

bash scripts/LeNet_300_100_FashionMNIST.sh -t merge -c l1-norm -r 0.5
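Pruning alone deletes neurons; merging compensates for them. The minimal NumPy sketch below illustrates the idea for two fully connected layers with a ReLU in between, under our own assumptions: each pruned neuron is matched to the remaining neuron whose weight vector (with the bias appended) has the highest cosine similarity, and, if that similarity passes a threshold, the pruned neuron's outgoing weights are added to its partner's, scaled by the ratio of the two vector norms. This works because ReLU(s·z) = s·ReLU(z) for any s ≥ 0. The threshold value and the names `merge_fc` and `forward` are hypothetical; this is not the repository's implementation.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def forward(W1, b1, W2, x):
    """Two-layer MLP block: W2 @ ReLU(W1 @ x + b1)."""
    return W2 @ relu(W1 @ x + b1)


def merge_fc(W1, b1, W2, pruned, threshold=0.45):
    """Remove `pruned` neurons from layer 1 and fold them into layer 2.

    W1: (n_hidden, n_in), b1: (n_hidden,), W2: (n_out, n_hidden).
    threshold is a placeholder value, not taken from the paper.
    """
    drop = set(pruned)
    kept = [i for i in range(W1.shape[0]) if i not in drop]
    vecs = np.hstack([W1, b1[:, None]])     # weight vector + bias per neuron
    norms = np.linalg.norm(vecs, axis=1)
    W2_new = W2.copy()
    for p in pruned:
        # Cosine similarity between pruned neuron p and every kept neuron.
        cos = vecs[kept] @ vecs[p] / (norms[kept] * norms[p])
        j = int(np.argmax(cos))
        if cos[j] >= threshold:
            k = kept[j]
            # If vec_p ~ s * vec_k with s > 0, then act_p ~ s * act_k,
            # so p's outgoing weights can be absorbed by neuron k.
            W2_new[:, k] += (norms[p] / norms[k]) * W2[:, p]
        # Pruned neurons without a similar partner are simply dropped.
    return W1[kept], b1[kept], W2_new[:, kept]
```

When a pruned neuron is an exact positive multiple of a kept one, the merged block reproduces the original outputs exactly; otherwise merging only approximates the original layer.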

VGG-16

To test the VGG-16 model on CIFAR-10, run:

bash scripts/VGG16_CIFAR10.sh -t [model type] -c [criterion]

You can use two arguments for this script:

  • model type: original | prune | merge
  • pruning criterion: l1-norm | l2-norm | l2-GM
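VGG-16 consists mostly of convolutional layers. The merging idea carries over if each filter is treated as one flattened vector: output channel p of one layer is input channel p of the next, so merging adds a scaled copy of the next layer's input-channel slice. The sketch below is hypothetical: it uses NumPy with a naive convolution only to check the result, the threshold is a placeholder, and it ignores batch normalization, which a real VGG-16 layer would require folding into the filters or handling separately.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def conv2d(x, w, b):
    """Naive stride-1 'valid' cross-correlation, as computed by PyTorch's Conv2d.

    x: (C_in, H, W), w: (C_out, C_in, kh, kw), b: (C_out,)
    """
    c_out, _, kh, kw = w.shape
    h_out, w_out = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.empty((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + kh, j:j + kw]
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return out


def merge_conv(w1, b1, w2, pruned, threshold=0.45):
    """Drop `pruned` filters of conv layer 1 and fold them into conv layer 2.

    w1: (C1, C0, kh, kw), b1: (C1,), w2: (C2, C1, kh2, kw2).
    """
    n = w1.shape[0]
    drop = set(pruned)
    kept = [i for i in range(n) if i not in drop]
    vecs = np.hstack([w1.reshape(n, -1), b1[:, None]])  # one row per filter
    norms = np.linalg.norm(vecs, axis=1)
    w2_new = w2.copy()
    for p in pruned:
        cos = vecs[kept] @ vecs[p] / (norms[kept] * norms[p])
        j = int(np.argmax(cos))
        if cos[j] >= threshold:
            k = kept[j]
            # Output channel p of layer 1 feeds input channel p of layer 2.
            w2_new[:, k] += (norms[p] / norms[k]) * w2[:, p]
    return w1[kept], b1[kept], w2_new[:, kept]
```

Because convolution is linear and ReLU acts element-wise per channel, a filter that is a positive multiple of another produces a feature map that is the same multiple of the other's, so the merge is exact in that case.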

A pretrained VGG-16 model for CIFAR-100 is not included, so you must train it first. To train VGG-16 on CIFAR-100, run:

bash scripts/VGG16_CIFAR100_train.sh

All hyperparameters follow those described in the paper's supplementary material.

After training, to test the VGG-16 model on CIFAR-100, run:

bash scripts/VGG16_CIFAR100.sh -t [model type] -c [criterion]

You can use two arguments for this script:

  • model type: original | prune | merge
  • pruning criterion: l1-norm | l2-norm | l2-GM

ResNet

To test the ResNet-56 model on CIFAR-10, run:

bash scripts/ResNet56_CIFAR10.sh -t [model type] -c [criterion] -r [pruning ratio]

You can use three arguments for this script:

  • model type: original | prune | merge
  • pruning criterion: l1-norm | l2-norm | l2-GM
  • pruning ratio: 0.0 ~ 1.0

To test the WideResNet-40-4 model on CIFAR-10, run:

bash scripts/WideResNet_40_4_CIFAR10.sh -t [model type] -c [criterion] -r [pruning ratio]

You can use three arguments for this script:

  • model type: original | prune | merge
  • pruning criterion: l1-norm | l2-norm | l2-GM
  • pruning ratio: 0.0 ~ 1.0
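To get a feel for what the pruning ratio means for the two residual networks, the arithmetic sketch below uses the standard CIFAR layouts: ResNet-(6n+2) with stage widths 16, 32, 64 (so ResNet-56 has n = 9 blocks per stage), and WRN-(6n+4)-k with widths 16k, 32k, 64k (so WideResNet-40-4 has n = 6). We assume, as is common for residual networks, that only the first convolution inside each basic block is pruned, so the block output stays aligned with the shortcut connection, and that the per-block count is floored. We have not checked that the scripts follow this scheme; the function names are hypothetical.

```python
def stage_layout(arch, depth, widen=1):
    """Blocks per stage and stage widths for CIFAR-style ResNet / WideResNet."""
    if arch == "resnet":        # ResNet-(6n+2): ResNet-56 -> n = 9
        n = (depth - 2) // 6
    elif arch == "wrn":         # WRN-(6n+4)-k: WRN-40-4 -> n = 6
        n = (depth - 4) // 6
    else:
        raise ValueError(f"unknown architecture: {arch}")
    return n, [16 * widen, 32 * widen, 64 * widen]


def prunable_channels(arch, depth, ratio, widen=1):
    """Channels removed if only the first conv of every basic block is pruned."""
    n, widths = stage_layout(arch, depth, widen)
    per_block = [int(w * ratio) for w in widths]  # assumed rounding: floor
    return {
        "blocks_per_stage": n,
        "pruned_per_block": per_block,
        "pruned_total": n * sum(per_block),
    }
```

Under these assumptions, a ratio of 0.5 removes 8, 16 and 32 channels per block in ResNet-56's three stages (504 in total), and 32, 64 and 128 per block in WideResNet-40-4 (1344 in total).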

Results

Without fine-tuning, our method achieves the following performance:

Image classification of LeNet-300-100 on FashionMNIST

Baseline accuracy: 89.80%

| Pruning ratio | Prune ($l_1$-norm) | Merge |
|---|---|---|
| 50% | 88.40% | 88.69% |
| 60% | 85.17% | 86.92% |
| 70% | 71.26% | 82.75% |
| 80% | 66.76% | 80.02% |
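The rows above use the $l_1$-norm criterion at four pruning ratios. A small helper that builds the matching script invocations, using only the documented `-t`, `-c` and `-r` flags, might look like the sketch below. The helper names are hypothetical, and we have not verified how the script reports its accuracy, so reading the numbers off its output is left to you.

```python
import shlex
import subprocess

SCRIPT = "scripts/LeNet_300_100_FashionMNIST.sh"


def build_commands(ratios=(0.5, 0.6, 0.7, 0.8), criterion="l1-norm",
                   model_types=("prune", "merge")):
    """One invocation per (ratio, model type) pair."""
    return [["bash", SCRIPT, "-t", t, "-c", criterion, "-r", str(r)]
            for r in ratios for t in model_types]


def run_all(commands):
    """Run each command from the repository root, stopping on the first failure."""
    for cmd in commands:
        print("$", shlex.join(cmd))
        subprocess.run(cmd, check=True)
```

Calling `run_all(build_commands())` from the repository root would run the eight prune/merge configurations in the table, one after another.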

Image classification of VGG-16 on CIFAR-10

Baseline accuracy: 93.70%

| Criterion | Prune | Merge |
|---|---|---|
| $l_1$-norm | 88.70% | 93.16% |
| $l_2$-norm | 89.14% | 93.16% |
| $l_2$-GM | 87.85% | 93.10% |
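To make the effect of merging explicit, the snippet below computes, from the two tables above, how many percentage points merging recovers over plain pruning and how far the merged VGG-16 remains below its 93.70% baseline. It only does arithmetic on the reported numbers; the names are illustrative.

```python
# (prune, merge) accuracies in %, copied from the tables above.
LENET = {"50%": (88.40, 88.69), "60%": (85.17, 86.92),
         "70%": (71.26, 82.75), "80%": (66.76, 80.02)}
VGG16 = {"l1-norm": (88.70, 93.16), "l2-norm": (89.14, 93.16),
         "l2-GM": (87.85, 93.10)}


def gains(table):
    """Accuracy recovered by merging over pruning, in percentage points."""
    return {k: round(merge - prune, 2) for k, (prune, merge) in table.items()}


def drop_from_baseline(table, baseline):
    """Remaining gap between the merged model and the unpruned baseline."""
    return {k: round(baseline - merge, 2) for k, (_, merge) in table.items()}
```

For LeNet-300-100 the recovered accuracy grows with the pruning ratio, from 0.29 points at 50% to 13.26 points at 80%. For VGG-16 merging recovers 4.02 to 5.25 points, leaving the merged model 0.54 to 0.6 points below the baseline.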

Citation

@inproceedings{kim2020merging,
  title     = {Neuron Merging: Compensating for Pruned Neurons},
  author    = {Kim, Woojeong and Kim, Suhyun and Park, Mincheol and Jeon, Geonseok},
  booktitle = {Advances in Neural Information Processing Systems 33},
  year      = {2020}
}