Learning Neural Network Subspaces

Overview

Learning Neural Network Subspaces

Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, Mohammad Rastegari.

Figure1

Abstract

Recent observations have advanced our understanding of the neural network optimization landscape, revealing the existence of (1) paths of high accuracy containing diverse solutions and (2) wider minima offering improved performance. Previous methods observing diverse paths require multiple training runs. In contrast we aim to leverage both property (1) and (2) with a single method and in a single training run. With a similar computational cost as training one model, we learn lines, curves, and simplexes of high-accuracy neural networks. These neural network subspaces contain diverse solutions that can be ensembled, approaching the ensemble performance of independently trained networks without the training cost. Moreover, using the subspace midpoint boosts accuracy, calibration, and robustness to label noise, outperforming Stochastic Weight Averaging.

Code Overview

In this repository we walk through learning neural network subspaces with PyTorch. We will ground the discussion with learning a line of neural networks. In our code, a line is defined by endpoints weight and weight1 and a point on the line is given by w = (1 - alpha) * weight + alpha * weight1 for some alpha in [0,1].

Algorithm 1 (see paper) works as follows:

  1. weight and weight1 are initialized independently.
  2. For each batch data, targets, alpha is chosen uniformly from [0,1] and the weights w = (1 - alpha) * weight + alpha * weight1 are used in the forward pass.
  3. The regularization term is computed (see Eq. 3).
  4. With loss.backward() and optimizer.step() the endpoints weight and weight1 are updated.

Instead of using a regular nn.Conv2d we instead use a SubspaceConv (found in modes/modules.py).

class SubspaceConv(nn.Conv2d):
    def forward(self, x):
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x

For each subspace type (lines, curves, and simplexes) the function get_weight must be implemented. For lines we use:

class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)

class LinesConv(TwoParamConv):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w

Note that the other endpoint weight is instantiated and initialized by nn.Conv2d. Also note that there is an equivalent implementation for batch norm layers also found in modes/modules.py.

Now we turn to the training logic which appears in trainers/train_one_dim_subspaces.py. In the snippet below we assume we are not training with the layerwise variant (args.layerwise = False) and we are drawing only one sample from the subspace (args.num_samples = 1).

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(args.device), target.to(args.device)

    alpha = np.random.uniform(0, 1)
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            setattr(m, f"alpha", alpha)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

All that's left is to compute the regularization term and call backward. For lines, this is given by the snippet below.

    num = 0.0
    norm = 0.0
    norm1 = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            num += (self.weight * self.weight1).sum()
            norm += self.weight.pow(2).sum()
            norm1 += self.weight1.pow(2).sum()
    loss += args.beta * (num.pow(2) / (norm * norm1))

    loss.backward()

    optimizer.step()

Training Lines, Curves, and Simplexes

We now walkthrough generating the plots in Figures 4 and 5 of the paper. Before running code please install PyTorch and Tensorboard (for making plots you will also need tex on your computer). Note that this repository differs from that used to generate the figures in the paper, as the latter leveraged Apple's internal tools. Accordingly there may be some bugs and we encourage you to submit an issue or send an email if you run into any problems.

In this example walkthrough we consider TinyImageNet, which we download to ~/data using a script such as this. To run standard training and ensemble the trained models, use the following command:

python experiment_configs/tinyimagenet/ensembles/train_ensemble_members.py
python experiment_configs/tinyimagenet/ensembles/eval_ensembles.py

Note that if your data is not in ~/data please change the paths in these experiment configs. Logs and checkpoints be saved in learning-subspaces-results, although this path can also be changed.

For one dimensional subspaces, use the following command to train:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_curves.py

To evaluate (i.e. generate the data for Figure 4) use:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_curves.py

We recommend looking at the experiment config files before running, which can be modified to change the type of model, number of random seeds. The default in these configs is 2 random seeds.

Analogously, to train simplexes use:

python experiment_configs/tinyimagenet/simplexes/train_simplexes.py
python experiment_configs/tinyimagenet/simplexes/train_simplexes_layerwise.py

For generating plots like those in Figure 4 and 5 use:

python analyze_results/tinyimagenet/one_dimensional_subspaces.py
python analyze_results/tinyimagenet/simplexes.py

Equivalent configs exist for other datasets, and the configs can be modified to add label noise, experiment with other models, and more. Also, if there is any functionality missing from this repository that you would like please also submit an issue.

Bibtex

@article{wortsman2021learning,
  title={Learning Neural Network Subspaces},
  author={Wortsman, Mitchell and Horton, Maxwell and Guestrin, Carlos and Farhadi, Ali and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2102.10472},
  year={2021}
}
Owner
Apple
Apple
PowerGridworld: A Framework for Multi-Agent Reinforcement Learning in Power Systems

PowerGridworld provides users with a lightweight, modular, and customizable framework for creating power-systems-focused, multi-agent Gym environments that readily integrate with existing training fr

National Renewable Energy Laboratory 37 Dec 17, 2022
Codes for our paper The Stem Cell Hypothesis: Dilemma behind Multi-Task Learning with Transformer Encoders published to EMNLP 2021.

The Stem Cell Hypothesis Codes for our paper The Stem Cell Hypothesis: Dilemma behind Multi-Task Learning with Transformer Encoders published to EMNLP

Emory NLP 5 Jul 08, 2022
Procedural 3D data generation pipeline for architecture

Synthetic Dataset Generator Authors: Stanislava Fedorova Alberto Tono Meher Shashwat Nigam Jiayao Zhang Amirhossein Ahmadnia Cecilia bolognesi Dominik

Computational Design Institute 49 Nov 25, 2022
Using VideoBERT to tackle video prediction

VideoBERT This repo reproduces the results of VideoBERT (https://arxiv.org/pdf/1904.01766.pdf). Inspiration was taken from https://github.com/MDSKUL/M

75 Dec 14, 2022
The codes reproduce the figures and statistics in the paper, "Controlling for multiple covariates," by Mark Tygert.

The accompanying codes reproduce all figures and statistics presented in "Controlling for multiple covariates" by Mark Tygert. This repository also pr

Meta Research 1 Dec 02, 2021
Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces(ICML 2021)

Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces(ICML 2021) This repository contains the code

149 Dec 15, 2022
QuadTree Attention for Vision Transformers (ICLR2022)

This repository contains codes for quadtree attention. This repo contains codes for feature matching, image classficiation, object detection and seman

tangshitao 222 Dec 28, 2022
Structured Edge Detection Toolbox

################################################################### # # # Structure

Piotr Dollar 779 Jan 02, 2023
CodeContests is a competitive programming dataset for machine-learning

CodeContests CodeContests is a competitive programming dataset for machine-learning. This dataset was used when training AlphaCode. It consists of pro

DeepMind 1.6k Jan 08, 2023
Test-Time Personalization with a Transformer for Human Pose Estimation, NeurIPS 2021

Transforming Self-Supervision in Test Time for Personalizing Human Pose Estimation This is an official implementation of the NeurIPS 2021 paper: Trans

41 Nov 28, 2022
Semiconductor Machine learning project

Wafer Fault Detection Problem Statement: Wafer (In electronics), also called a slice or substrate, is a thin slice of semiconductor, such as a crystal

kunal suryawanshi 1 Jan 15, 2022
A simple rest api that classifies pneumonia infection weather it is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest-x-ray image.

This is a simple rest api that classifies pneumonia infection weather it is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest-x-ray image.

crispengari 3 Jan 08, 2022
Improving 3D Object Detection with Channel-wise Transformer

"Improving 3D Object Detection with Channel-wise Transformer" Thanks for the OpenPCDet, this implementation of the CT3D is mainly based on the pcdet v

Hualian Sheng 107 Dec 20, 2022
v objective diffusion inference code for PyTorch.

v-diffusion-pytorch v objective diffusion inference code for PyTorch, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman). The

Katherine Crowson 635 Dec 30, 2022
Mapping Conditional Distributions for Domain Adaptation Under Generalized Target Shift

This repository contains the official code of OSTAR in "Mapping Conditional Distributions for Domain Adaptation Under Generalized Target Shift" (ICLR 2022).

Matthieu Kirchmeyer 5 Dec 06, 2022
CVPRW 2021: How to calibrate your event camera

E2Calib: How to Calibrate Your Event Camera This repository contains code that implements video reconstruction from event data for calibration as desc

Robotics and Perception Group 104 Nov 16, 2022
Recovering Brain Structure Network Using Functional Connectivity

Recovering-Brain-Structure-Network-Using-Functional-Connectivity Framework: Papers: This repository provides a PyTorch implementation of the models ad

5 Nov 30, 2022
Accuracy Aligned. Concise Implementation of Swin Transformer

Accuracy Aligned. Concise Implementation of Swin Transformer This repository contains the implementation of Swin Transformer, and the training codes o

FengWang 77 Dec 16, 2022
Implementation of the paper "Language-agnostic representation learning of source code from structure and context".

Code Transformer This is an official PyTorch implementation of the CodeTransformer model proposed in: D. Zügner, T. Kirschstein, M. Catasta, J. Leskov

Daniel Zügner 131 Dec 13, 2022
Pre-trained NFNets with 99% of the accuracy of the official paper

NFNet Pytorch Implementation This repo contains pretrained NFNet models F0-F6 with high ImageNet accuracy from the paper High-Performance Large-Scale

Benjamin Schmidt 133 Dec 09, 2022