Learning Neural Network Subspaces

Overview

Learning Neural Network Subspaces

Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, Mohammad Rastegari.

Figure1

Abstract

Recent observations have advanced our understanding of the neural network optimization landscape, revealing the existence of (1) paths of high accuracy containing diverse solutions and (2) wider minima offering improved performance. Previous methods observing diverse paths require multiple training runs. In contrast we aim to leverage both property (1) and (2) with a single method and in a single training run. With a similar computational cost as training one model, we learn lines, curves, and simplexes of high-accuracy neural networks. These neural network subspaces contain diverse solutions that can be ensembled, approaching the ensemble performance of independently trained networks without the training cost. Moreover, using the subspace midpoint boosts accuracy, calibration, and robustness to label noise, outperforming Stochastic Weight Averaging.

Code Overview

In this repository we walk through learning neural network subspaces with PyTorch. We will ground the discussion with learning a line of neural networks. In our code, a line is defined by endpoints weight and weight1 and a point on the line is given by w = (1 - alpha) * weight + alpha * weight1 for some alpha in [0,1].

Algorithm 1 (see paper) works as follows:

  1. weight and weight1 are initialized independently.
  2. For each batch data, targets, alpha is chosen uniformly from [0,1] and the weights w = (1 - alpha) * weight + alpha * weight1 are used in the forward pass.
  3. The regularization term is computed (see Eq. 3).
  4. With loss.backward() and optimizer.step() the endpoints weight and weight1 are updated.

Instead of using a regular nn.Conv2d we instead use a SubspaceConv (found in modes/modules.py).

class SubspaceConv(nn.Conv2d):
    def forward(self, x):
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x

For each subspace type (lines, curves, and simplexes) the function get_weight must be implemented. For lines we use:

class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)

class LinesConv(TwoParamConv):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w

Note that the other endpoint weight is instantiated and initialized by nn.Conv2d. Also note that there is an equivalent implementation for batch norm layers also found in modes/modules.py.

Now we turn to the training logic which appears in trainers/train_one_dim_subspaces.py. In the snippet below we assume we are not training with the layerwise variant (args.layerwise = False) and we are drawing only one sample from the subspace (args.num_samples = 1).

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(args.device), target.to(args.device)

    alpha = np.random.uniform(0, 1)
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            setattr(m, f"alpha", alpha)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

All that's left is to compute the regularization term and call backward. For lines, this is given by the snippet below.

    num = 0.0
    norm = 0.0
    norm1 = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            num += (self.weight * self.weight1).sum()
            norm += self.weight.pow(2).sum()
            norm1 += self.weight1.pow(2).sum()
    loss += args.beta * (num.pow(2) / (norm * norm1))

    loss.backward()

    optimizer.step()

Training Lines, Curves, and Simplexes

We now walkthrough generating the plots in Figures 4 and 5 of the paper. Before running code please install PyTorch and Tensorboard (for making plots you will also need tex on your computer). Note that this repository differs from that used to generate the figures in the paper, as the latter leveraged Apple's internal tools. Accordingly there may be some bugs and we encourage you to submit an issue or send an email if you run into any problems.

In this example walkthrough we consider TinyImageNet, which we download to ~/data using a script such as this. To run standard training and ensemble the trained models, use the following command:

python experiment_configs/tinyimagenet/ensembles/train_ensemble_members.py
python experiment_configs/tinyimagenet/ensembles/eval_ensembles.py

Note that if your data is not in ~/data please change the paths in these experiment configs. Logs and checkpoints be saved in learning-subspaces-results, although this path can also be changed.

For one dimensional subspaces, use the following command to train:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_curves.py

To evaluate (i.e. generate the data for Figure 4) use:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_curves.py

We recommend looking at the experiment config files before running, which can be modified to change the type of model, number of random seeds. The default in these configs is 2 random seeds.

Analogously, to train simplexes use:

python experiment_configs/tinyimagenet/simplexes/train_simplexes.py
python experiment_configs/tinyimagenet/simplexes/train_simplexes_layerwise.py

For generating plots like those in Figure 4 and 5 use:

python analyze_results/tinyimagenet/one_dimensional_subspaces.py
python analyze_results/tinyimagenet/simplexes.py

Equivalent configs exist for other datasets, and the configs can be modified to add label noise, experiment with other models, and more. Also, if there is any functionality missing from this repository that you would like please also submit an issue.

Bibtex

@article{wortsman2021learning,
  title={Learning Neural Network Subspaces},
  author={Wortsman, Mitchell and Horton, Maxwell and Guestrin, Carlos and Farhadi, Ali and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2102.10472},
  year={2021}
}
Owner
Apple
Apple
Unsupervised Foreground Extraction via Deep Region Competition

Unsupervised Foreground Extraction via Deep Region Competition [Paper] [Code] The official code repository for NeurIPS 2021 paper "Unsupervised Foregr

28 Nov 06, 2022
This is a package for LiDARTag, described in paper: LiDARTag: A Real-Time Fiducial Tag System for Point Clouds

LiDARTag Overview This is a package for LiDARTag, described in paper: LiDARTag: A Real-Time Fiducial Tag System for Point Clouds (PDF)(arXiv). This wo

University of Michigan Dynamic Legged Locomotion Robotics Lab 159 Dec 21, 2022
Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

FPT_data_centric_competition - Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

Pham Viet Hoang (Harry) 2 Oct 30, 2022
Semi-Supervised Semantic Segmentation with Pixel-Level Contrastive Learning from a Class-wise Memory Bank

This repository provides the official code for replicating experiments from the paper: Semi-Supervised Semantic Segmentation with Pixel-Level Contrast

Iñigo Alonso Ruiz 58 Dec 15, 2022
Anchor Retouching via Model Interaction for Robust Object Detection in Aerial Images

Anchor Retouching via Model Interaction for Robust Object Detection in Aerial Images In this paper, we present an effective Dynamic Enhancement Anchor

13 Dec 09, 2022
This library provides an abstraction to perform Model Versioning using Weight & Biases.

Description This library provides an abstraction to perform Model Versioning using Weight & Biases. Features Version a new trained model Promote a mod

Hector Lopez Almazan 2 Jan 28, 2022
Code for the paper "Offline Reinforcement Learning as One Big Sequence Modeling Problem"

Trajectory Transformer Code release for Offline Reinforcement Learning as One Big Sequence Modeling Problem. Installation All python dependencies are

Michael Janner 266 Dec 27, 2022
Fast and Easy Infinite Neural Networks in Python

Neural Tangents ICLR 2020 Video | Paper | Quickstart | Install guide | Reference docs | Release notes Overview Neural Tangents is a high-level neural

Google 1.9k Jan 09, 2023
a spacial-temporal pattern detection system for home automation

Argos a spacial-temporal pattern detection system for home automation. Based on OpenCV and Tensorflow, can run on raspberry pi and notify HomeAssistan

Angad Singh 133 Jan 05, 2023
Geometric Deep Learning Extension Library for PyTorch

Documentation | Paper | Colab Notebooks | External Resources | OGB Examples PyTorch Geometric (PyG) is a geometric deep learning extension library for

Matthias Fey 16.5k Jan 08, 2023
Speech Enhancement Generative Adversarial Network Based on Asymmetric AutoEncoder

ASEGAN: Speech Enhancement Generative Adversarial Network Based on Asymmetric AutoEncoder 中文版简介 Readme with English Version 介绍 基于SEGAN模型的改进版本,使用自主设计的非

Nitin 53 Nov 17, 2022
Bio-Computing Platform Featuring Large-Scale Representation Learning and Multi-Task Deep Learning “螺旋桨”生物计算工具集

English | 简体中文 Latest News 2021.10.25 Paper "Docking-based Virtual Screening with Multi-Task Learning" is accepted by BIBM 2021. 2021.07.29 PaddleHeli

633 Jan 04, 2023
Expressive Body Capture: 3D Hands, Face, and Body from a Single Image

Expressive Body Capture: 3D Hands, Face, and Body from a Single Image [Project Page] [Paper] [Supp. Mat.] Table of Contents License Description Fittin

Vassilis Choutas 1.3k Jan 07, 2023
Implementation of Neural Distance Embeddings for Biological Sequences (NeuroSEED) in PyTorch

Neural Distance Embeddings for Biological Sequences Official implementation of Neural Distance Embeddings for Biological Sequences (NeuroSEED) in PyTo

Gabriele Corso 56 Dec 23, 2022
text_recognition_toolbox: The reimplementation of a series of classical scene text recognition papers with Pytorch in a uniform way.

text recognition toolbox 1. 项目介绍 该项目是基于pytorch深度学习框架,以统一的改写方式实现了以下6篇经典的文字识别论文,论文的详情如下。该项目会持续进行更新,欢迎大家提出问题以及对代码进行贡献。 模型 论文标题 发表年份 模型方法划分 CRNN 《An End-t

168 Dec 24, 2022
Implementation of Change-Based Exploration Transfer (C-BET)

Implementation of Change-Based Exploration Transfer (C-BET), as presented in Interesting Object, Curious Agent: Learning Task-Agnostic Exploration.

Simone Parisi 29 Dec 04, 2022
From Perceptron model to Deep Neural Network from scratch in Python.

Neural-Network-Basics Aim of this Repository: From Perceptron model to Deep Neural Network (from scratch) in Python. ** Currently working on a basic N

Aditya Kahol 1 Jan 14, 2022
Frequency Spectrum Augmentation Consistency for Domain Adaptive Object Detection

Frequency Spectrum Augmentation Consistency for Domain Adaptive Object Detection Main requirements torch = 1.0 torchvision = 0.2.0 Python 3 Environm

15 Apr 04, 2022
Mortgage-loan-prediction - Show how to perform advanced Analytics and Machine Learning in Python using a full complement of PyData utilities

Mortgage-loan-prediction - Show how to perform advanced Analytics and Machine Learning in Python using a full complement of PyData utilities

Deepak Nandwani 1 Dec 31, 2021
BDDM: Bilateral Denoising Diffusion Models for Fast and High-Quality Speech Synthesis

Bilateral Denoising Diffusion Models (BDDMs) This is the official PyTorch implementation of the following paper: BDDM: BILATERAL DENOISING DIFFUSION M

172 Dec 23, 2022