Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization

Related tags

Deep Learningfishr
Overview

Fishr: Invariant Gradient Variances for Out-of-distribution Generalization

Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization | paper

Alexandre Ramé, Corentin Dancette, Matthieu Cord

Abstract

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest to learn simultaneously from multiple training domains - while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.

In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.

Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than Empirical Risk Minimization.

Installation

Requirements overview

Our implementation relies on the BackPACK package in PyTorch to easily compute gradient variances.

  • python == 3.7.10
  • torch == 1.8.1
  • torchvision == 0.9.1
  • backpack-for-pytorch == 1.3.0
  • numpy == 1.20.2

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/fishr.git
  1. Install this repository and the dependencies using pip:
$ conda create --name fishr python=3.7.10
$ conda activate fishr
$ cd fishr
$ pip install -r requirements.txt

With this, you can edit the Fishr code on the fly.

Overview

This github enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by IRM and (2) on the DomainBed benchmark.

Colored MNIST in the IRM setup

We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST.

Main results (Table 2 in Section 6.A)

To reproduce the results from Table 2, call python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm where algorithm is either:

Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal IRM implementation.

    Method | Train acc. | Test acc.  | Gray test acc.
   --------|------------|------------|----------------
    ERM    | 86.4 ± 0.2 | 14.0 ± 0.7 |   71.0 ± 0.7
    IRM    | 71.0 ± 0.5 | 65.6 ± 1.8 |   66.1 ± 0.2
    V-REx  | 71.7 ± 1.5 | 67.2 ± 1.5 |   68.6 ± 2.2
    Fishr  | 71.0 ± 0.9 | 69.5 ± 1.0 |   70.2 ± 1.1

Without label flipping (Table 5 in Appendix C.2.3)

The script coloredmnist.train_coloredmnist also accepts as input the argument --label_flipping_prob which defines the label flipping probability. By default, it's 0.25, so to reproduce the results from Table 5 you should set --label_flipping_prob 0.

Fishr variants (Table 6 in Appendix C.2.4)

This table considers two additional Fishr variants, reproduced with algorithm set to:

  • fishr_offdiagonal for Fishr but without centering the gradient variances
  • fishr_notcentered for Fishr but on the full covariance rather than only the diagonal

DomainBed

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in In Search of Lost Domain Generalization. Instructions below are copied and adapted from the official github.

Algorithms and hyperparameter grids

We added Fishr as a new algorithm here, and defined Fishr's hyperparameter grids here, as defined in Table 7 in Appendix D.

Datasets

We ran Fishr on following datasets:

Launch training

Download the datasets:

python3 -m domainbed.scripts.download\
       --data_dir=/my/data/dir

Train a model for debugging:

python3 -m domainbed.scripts.train\
       --data_dir=/my/data/dir/\
       --algorithm Fishr\
       --dataset ColoredMNIST\
       --test_env 2

Launch a sweep for hyperparameter search:

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher
       --datasets ColoredMNIST\
       --algorithms Fishr

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py.

Performances inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G)

To view the results of your sweep:

python -m domainbed.scripts.collect_results\
       --input_dir=/my/sweep/output/path

We inspect performances using following model selection criteria, that differ in what data is used to choose the best hyper-parameters for a given model:

  • OracleSelectionMethod (Oracle): A random subset from the data of the test domain.
  • IIDAccuracySelectionMethod (Training): A random subset from the data of the training domains.

Critically, Fishr performs consistently better than Empirical Risk Minimization.

Model selection Algorithm Colored MNIST Rotated MNIST VLCS PACS OfficeHome TerraIncognita DomainNet Avg
Oracle ERM 57.8 ± 0.2 97.8 ± 0.1 77.6 ± 0.3 86.7 ± 0.3 66.4 ± 0.5 53.0 ± 0.3 41.3 ± 0.1 68.7
Oracle Fishr 68.8 ± 1.4 97.8 ± 0.1 78.2 ± 0.2 86.9 ± 0.2 68.2 ± 0.2 53.6 ± 0.4 41.8 ± 0.2 70.8
Training ERM 51.5 ± 0.1 98.0 ± 0.0 77.5 ± 0.4 85.5 ± 0.2 66.5 ± 0.3 46.1 ± 1.8 40.9 ± 0.1 66.6
Training Fishr 52.0 ± 0.2 97.8 ± 0.0 77.8 ± 0.1 85.5 ± 0.4 67.8 ± 0.1 47.4 ± 1.6 41.7 ± 0.0 67.1

Conclusion

We addressed the task of out-of-distribution generalization for computer vision classification tasks. We derive a new and simple regularization - Fishr - that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performances on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization would consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help to use Fishr, please open an issue or contact [email protected].

Citation

If you find this code useful for your research, please consider citing our work (under review):

@article{rame2021ishr,
    title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
    author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2109.02934}
}
implementation for paper "ShelfNet for fast semantic segmentation"

ShelfNet-lightweight for paper (ShelfNet for fast semantic segmentation) This repo contains implementation of ShelfNet-lightweight models for real-tim

Juntang Zhuang 252 Sep 16, 2022
An official implementation of MobileStyleGAN in PyTorch

MobileStyleGAN: A Lightweight Convolutional Neural Network for High-Fidelity Image Synthesis Official PyTorch Implementation The accompanying videos c

Sergei Belousov 602 Jan 07, 2023
Transformers provides thousands of pretrained models to perform tasks on different modalities such as text, vision, and audio.

English | 简体中文 | 繁體中文 | 한국어 State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow 🤗 Transformers provides thousands of pretrained models

Clara Meister 50 Nov 12, 2022
PyTorch implementation of the Quasi-Recurrent Neural Network - up to 16 times faster than NVIDIA's cuDNN LSTM

Quasi-Recurrent Neural Network (QRNN) for PyTorch Updated to support multi-GPU environments via DataParallel - see the the multigpu_dataparallel.py ex

Salesforce 1.3k Dec 28, 2022
Classifying cat and dog images using Kaggle dataset

PyTorch Image Classification Classifies an image as containing either a dog or a cat (using Kaggle's public dataset), but could easily be extended to

Robert Coleman 74 Nov 22, 2022
Unoffical reMarkable AddOn for Firefox.

reMarkable for Firefox (Download) This repo converts the offical reMarkable Chrome Extension into a Firefox AddOn published here under the name "Unoff

Jelle Schutter 45 Nov 28, 2022
This implements one of result networks from Large-scale evolution of image classifiers

Exotic structured image classifier This implements one of result networks from Large-scale evolution of image classifiers by Esteban Real, et. al. Req

54 Nov 25, 2022
Pyramid R-CNN: Towards Better Performance and Adaptability for 3D Object Detection

Pyramid R-CNN: Towards Better Performance and Adaptability for 3D Object Detection

61 Jan 07, 2023
SemEval2022 Patronizing and Condescending Language (PCL) Detection

SemEval2022 Patronizing and Condescending Language (PCL) Detection This task is from SemEval 2022. What is Patronizing and Condescending Language (PCL

Daniel Saeedi 0 Aug 05, 2022
GalaXC: Graph Neural Networks with Labelwise Attention for Extreme Classification

GalaXC GalaXC: Graph Neural Networks with Labelwise Attention for Extreme Classification @InProceedings{Saini21, author = {Saini, D. and Jain,

Extreme Classification 28 Dec 05, 2022
Open & Efficient for Framework for Aspect-based Sentiment Analysis

PyABSA - Open & Efficient for Framework for Aspect-based Sentiment Analysis Fast & Low Memory requirement & Enhanced implementation of Local Context F

YangHeng 567 Jan 07, 2023
SCALoss: Side and Corner Aligned Loss for Bounding Box Regression (AAAI2022).

SCALoss PyTorch implementation of the paper "SCALoss: Side and Corner Aligned Loss for Bounding Box Regression" (AAAI 2022). Introduction IoU-based lo

TuZheng 20 Sep 07, 2022
NaturalCC is a sequence modeling toolkit that allows researchers and developers to train custom models

NaturalCC NaturalCC is a sequence modeling toolkit that allows researchers and developers to train custom models for many software engineering tasks,

159 Dec 28, 2022
(Arxiv 2021) NeRF--: Neural Radiance Fields Without Known Camera Parameters

NeRF--: Neural Radiance Fields Without Known Camera Parameters Project Page | Arxiv | Colab Notebook | Data Zirui Wang¹, Shangzhe Wu², Weidi Xie², Min

Active Vision Laboratory 411 Dec 26, 2022
An All-MLP solution for Vision, from Google AI

MLP Mixer - Pytorch An All-MLP solution for Vision, from Google AI, in Pytorch. No convolutions nor attention needed! Yannic Kilcher video Install $ p

Phil Wang 784 Jan 06, 2023
Out-of-Distribution Generalization of Chest X-ray Using Risk Extrapolation

OoD_Gen-Chest_Xray Out-of-Distribution Generalization of Chest X-ray Using Risk Extrapolation Requirements (Installations) Install the following libra

Enoch Tetteh 2 Oct 01, 2022
Inhomogeneous Social Recommendation with Hypergraph Convolutional Networks

Inhomogeneous Social Recommendation with Hypergraph Convolutional Networks This is our Pytorch implementation for the paper: Zirui Zhu, Chen Gao, Xu C

Zirui Zhu 3 Dec 30, 2022
Project looking into use of autoencoder for semi-supervised learning and comparing data requirements compared to supervised learning.

Project looking into use of autoencoder for semi-supervised learning and comparing data requirements compared to supervised learning.

Tom-R.T.Kvalvaag 2 Dec 17, 2021
Info and sample codes for "NTU RGB+D Action Recognition Dataset"

"NTU RGB+D" Action Recognition Dataset "NTU RGB+D 120" Action Recognition Dataset "NTU RGB+D" is a large-scale dataset for human action recognition. I

Amir Shahroudy 578 Dec 30, 2022
The repository for our EMNLP 2021 paper "Finnish Dialect Identification: The Effect of Audio and Text"

Finnish Dialect Identification The repository for our EMNLP 2021 paper "Finnish Dialect Identification: The Effect of Audio and Text". We present a te

Rootroo Ltd 2 Dec 25, 2021