Code for "Training Neural Networks with Fixed Sparse Masks" (NeurIPS 2021).

Related tags

Deep LearningFISH
Overview

Fisher Induced Sparse uncHanging (FISH) Mask

This repo contains the code for Fisher Induced Sparse uncHanging (FISH) Mask training, from "Training Neural Networks with Fixed Sparse Masks" by Yi-Lin Sung, Varun Nair, and Colin Raffel. To appear in Neural Information Processing Systems (NeurIPS) 2021.

Abstract: During typical gradient-based training of deep neural networks, all of the model's parameters are updated at each iteration. Recent work has shown that it is possible to update only a small subset of the model's parameters during training, which can alleviate storage and communication requirements. In this paper, we show that it is possible to induce a fixed sparse mask on the model’s parameters that selects a subset to update over many iterations. Our method constructs the mask out of the parameters with the largest Fisher information as a simple approximation as to which parameters are most important for the task at hand. In experiments on parameter-efficient transfer learning and distributed training, we show that our approach matches or exceeds the performance of other methods for training with sparse updates while being more efficient in terms of memory usage and communication costs.

Setup

pip install transformers/.
pip install datasets torch==1.8.0 tqdm torchvision==0.9.0

FISH Mask: GLUE Experiments

Parameter-Efficient Transfer Learning

To run the FISH Mask on a GLUE dataset, code can be run with the following format:

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh <dataset-name> <seed> <top_k_percentage> <num_samples_for_fisher>

An example command used to generate Table 1 in the paper is as follows, where all GLUE tasks are provided at a seed of 0 and a FISH mask sparsity of 0.5%.

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh "qqp mnli rte cola stsb sst2 mrpc qnli" 0 0.005 1024

Distributed Training

To use the FISH mask on the GLUE tasks in a distributed setting, one can use the following command.

$ bash transformers/examples/text-classification/scripts/distributed_training.sh <dataset-name> <seed> <num_workers> <training_epochs> <gpu_id>

Note the <dataset-name> here can only contain one task, so an example command could be

$ bash transformers/examples/text-classification/scripts/distributed_training.sh "mnli" 0 2 3.5 0

FISH Mask: CIFAR10 Experiments

To run the FISH mask on CIFAR10, code can be run with the following format:

Distributed Training

$ bash cifar10-fast/scripts/distributed_training_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <worker_updates> <learning_rate> <num_workers>

For example, in the paper, we compute the FISH mask of the 0.5% sparsity level by 256 samples and distribute the job to 2 workers for a total of 50 epochs training. Then the command would be

$ bash cifar10-fast/scripts/distributed_training_fish.sh 256 0.005 50 2 0.4 2

Efficient Checkpointing

$ bash cifar10-fast/scripts/small_checkpoints_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <learning_rate> <fix_mask>

The hyperparameters are almost the same as distributed training. However, the <fix_mask> is to indicate to fix the mask or not, and a valid input is either 0 or 1 (1 means to fix the mask).

Replicating Results

Replicating each of the tables and figures present in the original paper can be done by running the following:

# Table 1 - Parameter Efficient Fine-Tuning on GLUE

$ bash transformers/examples/text-classification/scripts/run_table_1.sh
# Figure 2 - Mask Sparsity Ablation and Sample Ablation

$ bash transformers/examples/text-classification/scripts/run_figure_2.sh
# Table 2 - Distributed Training on GLUE

$ bash transformers/examples/text-classification/scripts/run_table_2.sh
# Table 3 - Distributed Training on CIFAR10

$ bash cifar10-fast/scripts/distributed_training.sh

# Table 4 - Efficient Checkpointing

$ bash cifar10-fast/scripts/small_checkpoints.sh

Notes

  • For reproduction of Diff Pruning results from Table 1, see code here.

Acknowledgements

We thank Yoon Kim, Michael Matena, and Demi Guo for helpful discussions.

Owner
Varun Nair
Hi! I'm a student at Duke University studying CS. I'm interested in researching AI/ML and its applications in medicine, transportation, & education.
Varun Nair
Official implementation of the ICCV 2021 paper "Joint Inductive and Transductive Learning for Video Object Segmentation"

JOINT This is the official implementation of Joint Inductive and Transductive learning for Video Object Segmentation, to appear in ICCV 2021. @inproce

Yunyao 35 Oct 16, 2022
ICNet for Real-Time Semantic Segmentation on High-Resolution Images, ECCV2018

ICNet for Real-Time Semantic Segmentation on High-Resolution Images by Hengshuang Zhao, Xiaojuan Qi, Xiaoyong Shen, Jianping Shi, Jiaya Jia, details a

Hengshuang Zhao 594 Dec 31, 2022
Tree-based Search Graph for Approximate Nearest Neighbor Search

TBSG: Tree-based Search Graph for Approximate Nearest Neighbor Search. TBSG is a graph-based algorithm for ANNS based on Cover Tree, which is also an

Fanxbin 2 Dec 27, 2022
Adversarial-autoencoders - Tensorflow implementation of Adversarial Autoencoders

Adversarial Autoencoders (AAE) Tensorflow implementation of Adversarial Autoencoders (ICLR 2016) Similar to variational autoencoder (VAE), AAE imposes

Qian Ge 236 Nov 13, 2022
Multimodal Descriptions of Social Concepts: Automatic Modeling and Detection of (Highly Abstract) Social Concepts evoked by Art Images

MUSCO - Multimodal Descriptions of Social Concepts Automatic Modeling of (Highly Abstract) Social Concepts evoked by Art Images This project aims to i

0 Aug 22, 2021
🔥3D-RecGAN in Tensorflow (ICCV Workshops 2017)

3D Object Reconstruction from a Single Depth View with Adversarial Learning Bo Yang, Hongkai Wen, Sen Wang, Ronald Clark, Andrew Markham, Niki Trigoni

Bo Yang 125 Nov 26, 2022
hySLAM is a hybrid SLAM/SfM system designed for mapping

HySLAM Overview hySLAM is a hybrid SLAM/SfM system designed for mapping. The system is based on ORB-SLAM2 with some modifications and refactoring. Raú

Brian Hopkinson 15 Oct 10, 2022
Lowest memory consumption and second shortest runtime in NTIRE 2022 challenge on Efficient Super-Resolution

FMEN Lowest memory consumption and second shortest runtime in NTIRE 2022 on Efficient Super-Resolution. Our paper: Fast and Memory-Efficient Network T

33 Dec 01, 2022
ByteTrack: Multi-Object Tracking by Associating Every Detection Box

ByteTrack ByteTrack is a simple, fast and strong multi-object tracker. ByteTrack: Multi-Object Tracking by Associating Every Detection Box Yifu Zhang,

Yifu Zhang 2.9k Jan 04, 2023
RP-GAN: Stable GAN Training with Random Projections

RP-GAN: Stable GAN Training with Random Projections This repository contains a reference implementation of the algorithm described in the paper: Behna

Ayan Chakrabarti 20 Sep 18, 2021
Compartmental epidemic model to assess undocumented infections: applications to SARS-CoV-2 epidemics in Brazil - Datasets and Codes

Compartmental epidemic model to assess undocumented infections: applications to SARS-CoV-2 epidemics in Brazil - Datasets and Codes The codes for simu

1 Jan 12, 2022
A self-supervised learning framework for audio-visual speech

AV-HuBERT (Audio-Visual Hidden Unit BERT) Learning Audio-Visual Speech Representation by Masked Multimodal Cluster Prediction Robust Self-Supervised A

Meta Research 431 Jan 07, 2023
CR-FIQA: Face Image Quality Assessment by Learning Sample Relative Classifiability

This is the official repository of the paper: CR-FIQA: Face Image Quality Assessment by Learning Sample Relative Classifiability A private copy of the

Fadi Boutros 33 Dec 31, 2022
pytorch implementation of GPV-Pose

GPV-Pose Pytorch implementation of GPV-Pose: Category-level Object Pose Estimation via Geometry-guided Point-wise Voting. (link) UPDATE A new version

40 Dec 01, 2022
Trafffic prediction analysis using hybrid models - Machine Learning

Hybrid Machine learning Model Clone the Repository Create a new Directory as assests and download the model from the below link Model Link To Start th

1 Feb 08, 2022
pytorch implementation of openpose including Hand and Body Pose Estimation.

pytorch-openpose pytorch implementation of openpose including Body and Hand Pose Estimation, and the pytorch model is directly converted from openpose

Hzzone 1.4k Jan 07, 2023
[NeurIPS 2021] Deceive D: Adaptive Pseudo Augmentation for GAN Training with Limited Data

Deceive D: Adaptive Pseudo Augmentation for GAN Training with Limited Data (NeurIPS 2021) This repository will provide the official PyTorch implementa

Liming Jiang 238 Nov 25, 2022
Angora is a mutation-based fuzzer. The main goal of Angora is to increase branch coverage by solving path constraints without symbolic execution.

Angora Angora is a mutation-based coverage guided fuzzer. The main goal of Angora is to increase branch coverage by solving path constraints without s

833 Jan 07, 2023
Image process framework based on plugin like imagej, it is esay to glue with scipy.ndimage, scikit-image, opencv, simpleitk, mayavi...and any libraries based on numpy

Introduction ImagePy is an open source image processing framework written in Python. Its UI interface, image data structure and table data structure a

ImagePy 1.2k Dec 29, 2022
ADB-IP-ROTATION - Use your mobile phone to gain a temporary IP address using ADB and data tethering

ADB IP ROTATE This an Python script based on Android Debug Bridge (adb) shell sc

Dor Bismuth 2 Jul 12, 2022