Learning Neural Network Subspaces

Overview

Learning Neural Network Subspaces

Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, Mohammad Rastegari.

Figure1

Abstract

Recent observations have advanced our understanding of the neural network optimization landscape, revealing the existence of (1) paths of high accuracy containing diverse solutions and (2) wider minima offering improved performance. Previous methods observing diverse paths require multiple training runs. In contrast we aim to leverage both property (1) and (2) with a single method and in a single training run. With a similar computational cost as training one model, we learn lines, curves, and simplexes of high-accuracy neural networks. These neural network subspaces contain diverse solutions that can be ensembled, approaching the ensemble performance of independently trained networks without the training cost. Moreover, using the subspace midpoint boosts accuracy, calibration, and robustness to label noise, outperforming Stochastic Weight Averaging.

Code Overview

In this repository we walk through learning neural network subspaces with PyTorch. We will ground the discussion with learning a line of neural networks. In our code, a line is defined by endpoints weight and weight1 and a point on the line is given by w = (1 - alpha) * weight + alpha * weight1 for some alpha in [0,1].

Algorithm 1 (see paper) works as follows:

  1. weight and weight1 are initialized independently.
  2. For each batch data, targets, alpha is chosen uniformly from [0,1] and the weights w = (1 - alpha) * weight + alpha * weight1 are used in the forward pass.
  3. The regularization term is computed (see Eq. 3).
  4. With loss.backward() and optimizer.step() the endpoints weight and weight1 are updated.

Instead of using a regular nn.Conv2d we instead use a SubspaceConv (found in modes/modules.py).

class SubspaceConv(nn.Conv2d):
    def forward(self, x):
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x

For each subspace type (lines, curves, and simplexes) the function get_weight must be implemented. For lines we use:

class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)

class LinesConv(TwoParamConv):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w

Note that the other endpoint weight is instantiated and initialized by nn.Conv2d. Also note that there is an equivalent implementation for batch norm layers also found in modes/modules.py.

Now we turn to the training logic which appears in trainers/train_one_dim_subspaces.py. In the snippet below we assume we are not training with the layerwise variant (args.layerwise = False) and we are drawing only one sample from the subspace (args.num_samples = 1).

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(args.device), target.to(args.device)

    alpha = np.random.uniform(0, 1)
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            setattr(m, f"alpha", alpha)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

All that's left is to compute the regularization term and call backward. For lines, this is given by the snippet below.

    num = 0.0
    norm = 0.0
    norm1 = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            num += (self.weight * self.weight1).sum()
            norm += self.weight.pow(2).sum()
            norm1 += self.weight1.pow(2).sum()
    loss += args.beta * (num.pow(2) / (norm * norm1))

    loss.backward()

    optimizer.step()

Training Lines, Curves, and Simplexes

We now walkthrough generating the plots in Figures 4 and 5 of the paper. Before running code please install PyTorch and Tensorboard (for making plots you will also need tex on your computer). Note that this repository differs from that used to generate the figures in the paper, as the latter leveraged Apple's internal tools. Accordingly there may be some bugs and we encourage you to submit an issue or send an email if you run into any problems.

In this example walkthrough we consider TinyImageNet, which we download to ~/data using a script such as this. To run standard training and ensemble the trained models, use the following command:

python experiment_configs/tinyimagenet/ensembles/train_ensemble_members.py
python experiment_configs/tinyimagenet/ensembles/eval_ensembles.py

Note that if your data is not in ~/data please change the paths in these experiment configs. Logs and checkpoints be saved in learning-subspaces-results, although this path can also be changed.

For one dimensional subspaces, use the following command to train:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_curves.py

To evaluate (i.e. generate the data for Figure 4) use:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_curves.py

We recommend looking at the experiment config files before running, which can be modified to change the type of model, number of random seeds. The default in these configs is 2 random seeds.

Analogously, to train simplexes use:

python experiment_configs/tinyimagenet/simplexes/train_simplexes.py
python experiment_configs/tinyimagenet/simplexes/train_simplexes_layerwise.py

For generating plots like those in Figure 4 and 5 use:

python analyze_results/tinyimagenet/one_dimensional_subspaces.py
python analyze_results/tinyimagenet/simplexes.py

Equivalent configs exist for other datasets, and the configs can be modified to add label noise, experiment with other models, and more. Also, if there is any functionality missing from this repository that you would like please also submit an issue.

Bibtex

@article{wortsman2021learning,
  title={Learning Neural Network Subspaces},
  author={Wortsman, Mitchell and Horton, Maxwell and Guestrin, Carlos and Farhadi, Ali and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2102.10472},
  year={2021}
}
Owner
Apple
Apple
Music Source Separation; Train & Eval & Inference piplines and pretrained models we used for 2021 ISMIR MDX Challenge.

Introduction 1. Usage (For MSS) 1.1 Prepare running environment 1.2 Use pretrained model 1.3 Train new MSS models from scratch 1.3.1 How to train 1.3.

Leo 100 Dec 25, 2022
SANet: A Slice-Aware Network for Pulmonary Nodule Detection

SANet: A Slice-Aware Network for Pulmonary Nodule Detection This paper (SANet) has been accepted and early accessed in IEEE TPAMI 2021. This code and

Jie Mei 39 Dec 17, 2022
Generic ecosystem for feature extraction from aerial and satellite imagery

Note: Robosat is neither maintained not actively developed any longer by Mapbox. See this issue. The main developers (@daniel-j-h, @bkowshik) are no l

Mapbox 1.9k Jan 06, 2023
The code for 'Deep Residual Fourier Transformation for Single Image Deblurring'

Deep Residual Fourier Transformation for Single Image Deblurring Xintian Mao, Yiming Liu, Wei Shen, Qingli Li and Yan Wang News 2021.12.5 Release Deep

145 Jan 05, 2023
git git《Transformer Meets Tracker: Exploiting Temporal Context for Robust Visual Tracking》(CVPR 2021) GitHub:git2] 《Masksembles for Uncertainty Estimation》(CVPR 2021) GitHub:git3]

Transformer Meets Tracker: Exploiting Temporal Context for Robust Visual Tracking Ning Wang, Wengang Zhou, Jie Wang, and Houqiang Li Accepted by CVPR

NingWang 236 Dec 22, 2022
Largest list of models for Core ML (for iOS 11+)

Since iOS 11, Apple released Core ML framework to help developers integrate machine learning models into applications. The official documentation We'v

Kedan Li 5.6k Jan 08, 2023
A pure PyTorch batched computation implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"

A pure PyTorch batched computation implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"

張致強 14 Dec 02, 2022
Mask-invariant Face Recognition through Template-level Knowledge Distillation

Mask-invariant Face Recognition through Template-level Knowledge Distillation This is the official repository of "Mask-invariant Face Recognition thro

Fadi Boutros 35 Dec 06, 2022
A TensorFlow implementation of Neural Program Synthesis from Diverse Demonstration Videos

ViZDoom http://vizdoom.cs.put.edu.pl ViZDoom allows developing AI bots that play Doom using only the visual information (the screen buffer). It is pri

Hyeonwoo Noh 1 Aug 19, 2020
A time series processing library

Timeseria Timeseria is a time series processing library which aims at making it easy to handle time series data and to build statistical and machine l

Stefano Alberto Russo 11 Aug 08, 2022
OpenDelta - An Open-Source Framework for Paramter Efficient Tuning.

OpenDelta is a toolkit for parameter efficient methods (we dub it as delta tuning), by which users could flexibly assign (or add) a small amount parameters to update while keeping the most paramters

THUNLP 386 Dec 26, 2022
Detecting drunk people through thermal images using Deep Learning (CNN)

Drunk Detection CNN Detecting drunk people through thermal images using Deep Learning (CNN) Dataset We used thermal images provided by Electronics Lab

Giacomo Ferretti 3 Oct 27, 2022
The code for Bi-Mix: Bidirectional Mixing for Domain Adaptive Nighttime Semantic Segmentation

BiMix The code for Bi-Mix: Bidirectional Mixing for Domain Adaptive Nighttime Semantic Segmentation arxiv Framework: visualization results: Requiremen

stanley 18 Sep 18, 2022
Pytorch code for "State-only Imitation with Transition Dynamics Mismatch" (ICLR 2020)

This repo contains code for our paper State-only Imitation with Transition Dynamics Mismatch published at ICLR 2020. The code heavily uses the RL mach

20 Sep 08, 2022
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Pawel Dziemiach 1 Dec 19, 2021
Physics-Informed Neural Networks (PINN) and Deep BSDE Solvers of Differential Equations for Scientific Machine Learning (SciML) accelerated simulation

NeuralPDE NeuralPDE.jl is a solver package which consists of neural network solvers for partial differential equations using scientific machine learni

SciML Open Source Scientific Machine Learning 680 Jan 02, 2023
Realtime segmentation with ENet, the fast and accurate segmentation net.

Enet This is a realtime segmentation net with almost 22 fps on GTX1080 ti, and the model size is very small with only 28M. This repo contains the infe

JinTian 14 Aug 30, 2022
A standard framework for modelling Deep Learning Models for tabular data

PyTorch Tabular aims to make Deep Learning with Tabular data easy and accessible to real-world cases and research alike.

801 Jan 08, 2023
Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al.

nam-pytorch Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al. [abs, pdf] Installation You can access nam-pytorch vi

Rishabh Anand 11 Mar 14, 2022