
Fishr: Invariant Gradient Variances for Out-of-distribution Generalization

Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization | paper

Alexandre Ramé, Corentin Dancette, Matthieu Cord

Abstract

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest in learning simultaneously from multiple training domains, while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.

In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.
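To give the flavor of the method, here is a minimal sketch of such a penalty in PyTorch (our simplification for illustration, not the exact code from this repository; per_domain_grads is a hypothetical list of per-sample gradient matrices, one per training domain):

import torch

# Minimal sketch: match each domain's component-wise gradient variance to the
# mean variance across domains. Each entry of `per_domain_grads` is a tensor of
# shape (n_samples_in_domain, n_params) holding per-sample gradients.
def fishr_penalty(per_domain_grads):
    variances = [g.var(dim=0) for g in per_domain_grads]  # one vector per domain
    mean_variance = torch.stack(variances).mean(dim=0)    # average across domains
    return sum((v - mean_variance).pow(2).mean() for v in variances)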

Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than Empirical Risk Minimization.

Installation

Requirements overview

Our implementation relies on the BackPACK package, built on top of PyTorch, to easily compute gradient variances; a usage sketch follows the requirements list below.

  • python == 3.7.10
  • torch == 1.8.1
  • torchvision == 0.9.1
  • backpack-for-pytorch == 1.3.0
  • numpy == 1.20.2
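
For illustration, here is a hedged sketch of how BackPACK exposes per-parameter gradient variances, adapted from BackPACK's documented usage (the model and data below are toy placeholders, not taken from this repository):

import torch
from backpack import backpack, extend
from backpack.extensions import Variance

# BackPACK requires both the model and the loss to be extended.
model = extend(torch.nn.Linear(10, 2))
lossfunc = extend(torch.nn.CrossEntropyLoss())

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = lossfunc(model(x), y)
with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
    # param.variance holds the variance of the per-sample gradients,
    # with the same shape as the parameter itself.
    print(name, param.variance.shape)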

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/fishr.git
  2. Install this repository and the dependencies using pip:
$ conda create --name fishr python=3.7.10
$ conda activate fishr
$ cd fishr
$ pip install -r requirements.txt

With this, you can edit the Fishr code on the fly.
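
As a quick sanity check of the environment (our suggestion, not part of the repository), you can verify that the key imports resolve:

$ python -c "import torch, torchvision, backpack; print(torch.__version__)"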

Overview

This repository enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by IRM, and (2) on the DomainBed benchmark.

Colored MNIST in the IRM setup

We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST.

Main results (Table 2 in Section 6.A)

To reproduce the results from Table 2, call python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm where algorithm is either:

  • erm for Empirical Risk Minimization
  • irm for Invariant Risk Minimization
  • rex for Risk Extrapolation (V-REx)
  • fishr for our proposed Fishr

Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal IRM implementation.
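
For instance, to train with the Fishr regularization:

$ python3 coloredmnist/train_coloredmnist.py --algorithm fishr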

    Method | Train acc. | Test acc.  | Gray test acc.
   --------|------------|------------|----------------
    ERM    | 86.4 ± 0.2 | 14.0 ± 0.7 |   71.0 ± 0.7
    IRM    | 71.0 ± 0.5 | 65.6 ± 1.8 |   66.1 ± 0.2
    V-REx  | 71.7 ± 1.5 | 67.2 ± 1.5 |   68.6 ± 2.2
    Fishr  | 71.0 ± 0.9 | 69.5 ± 1.0 |   70.2 ± 1.1

Without label flipping (Table 5 in Appendix C.2.3)

The script coloredmnist/train_coloredmnist.py also accepts the argument --label_flipping_prob, which defines the label flipping probability. It defaults to 0.25, so to reproduce the results from Table 5 you should set --label_flipping_prob 0.
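
For instance:

$ python3 coloredmnist/train_coloredmnist.py --algorithm fishr --label_flipping_prob 0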

Fishr variants (Table 6 in Appendix C.2.4)

This table considers two additional Fishr variants, reproduced with algorithm set to:

  • fishr_offdiagonal for Fishr but on the full covariance rather than only the diagonal
  • fishr_notcentered for Fishr but without centering the gradient variances
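
To make the distinction concrete, here is an illustrative sketch (our simplification, not the repository's code) of the gradient statistic each variant would match across domains, given one domain's per-sample gradient matrix g of shape (n_samples, n_params):

import torch

def grad_statistic(g, variant="fishr"):
    if variant == "fishr":
        # Centered, diagonal only: the component-wise gradient variance.
        return g.var(dim=0, unbiased=False)
    if variant == "fishr_notcentered":
        # Uncentered second moment, diagonal only.
        return g.pow(2).mean(dim=0)
    if variant == "fishr_offdiagonal":
        # Full (centered) covariance matrix, including off-diagonal terms.
        centered = g - g.mean(dim=0, keepdim=True)
        return centered.T @ centered / g.shape[0]
    raise ValueError(variant)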

DomainBed

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in In Search of Lost Domain Generalization. Instructions below are copied and adapted from the official github.

Algorithms and hyperparameter grids

We added Fishr as a new algorithm in domainbed/algorithms.py, and defined Fishr's hyperparameter grids in domainbed/hparams_registry.py, following Table 7 in Appendix D.

Datasets

We ran Fishr on the following datasets: Colored MNIST, Rotated MNIST, VLCS, PACS, OfficeHome, TerraIncognita, and DomainNet.

Launch training

Download the datasets:

python3 -m domainbed.scripts.download \
       --data_dir=/my/data/dir

Train a model for debugging:

python3 -m domainbed.scripts.train \
       --data_dir=/my/data/dir/ \
       --algorithm Fishr \
       --dataset ColoredMNIST \
       --test_env 2

Launch a sweep for hyperparameter search:

python -m domainbed.scripts.sweep launch \
       --data_dir=/my/data/dir/ \
       --output_dir=/my/sweep/output/path \
       --command_launcher MyLauncher \
       --datasets ColoredMNIST \
       --algorithms Fishr

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py.

Performances inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G)

To view the results of your sweep:

python -m domainbed.scripts.collect_results \
       --input_dir=/my/sweep/output/path

We inspect performances using the following model selection criteria, which differ in what data is used to choose the best hyperparameters for a given model:

  • OracleSelectionMethod (Oracle): A random subset from the data of the test domain.
  • IIDAccuracySelectionMethod (Training): A random subset from the data of the training domains.

Critically, Fishr performs consistently better than Empirical Risk Minimization.

 Model selection | Algorithm | Colored MNIST | Rotated MNIST |    VLCS    |    PACS    | OfficeHome | TerraIncognita | DomainNet  | Avg
-----------------|-----------|---------------|---------------|------------|------------|------------|----------------|------------|------
 Oracle          | ERM       |  57.8 ± 0.2   |  97.8 ± 0.1   | 77.6 ± 0.3 | 86.7 ± 0.3 | 66.4 ± 0.5 |   53.0 ± 0.3   | 41.3 ± 0.1 | 68.7
 Oracle          | Fishr     |  68.8 ± 1.4   |  97.8 ± 0.1   | 78.2 ± 0.2 | 86.9 ± 0.2 | 68.2 ± 0.2 |   53.6 ± 0.4   | 41.8 ± 0.2 | 70.8
 Training        | ERM       |  51.5 ± 0.1   |  98.0 ± 0.0   | 77.5 ± 0.4 | 85.5 ± 0.2 | 66.5 ± 0.3 |   46.1 ± 1.8   | 40.9 ± 0.1 | 66.6
 Training        | Fishr     |  52.0 ± 0.2   |  97.8 ± 0.0   | 77.8 ± 0.1 | 85.5 ± 0.4 | 67.8 ± 0.1 |   47.4 ± 1.6   | 41.7 ± 0.0 | 67.1

Conclusion

We addressed out-of-distribution generalization for computer vision classification tasks. We derived a new and simple regularization, Fishr, that matches gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performance on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization can consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help using Fishr, please open an issue or contact [email protected].

Citation

If you find this code useful for your research, please consider citing our work (under review):

@article{rame2021fishr,
    title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
    author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2109.02934}
}