Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization

Related tags

Deep Learningfishr
Overview

Fishr: Invariant Gradient Variances for Out-of-distribution Generalization

Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization | paper

Alexandre Ramé, Corentin Dancette, Matthieu Cord

Abstract

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest to learn simultaneously from multiple training domains - while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.

In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.

Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than Empirical Risk Minimization.

Installation

Requirements overview

Our implementation relies on the BackPACK package in PyTorch to easily compute gradient variances.

  • python == 3.7.10
  • torch == 1.8.1
  • torchvision == 0.9.1
  • backpack-for-pytorch == 1.3.0
  • numpy == 1.20.2

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/fishr.git
  1. Install this repository and the dependencies using pip:
$ conda create --name fishr python=3.7.10
$ conda activate fishr
$ cd fishr
$ pip install -r requirements.txt

With this, you can edit the Fishr code on the fly.

Overview

This github enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by IRM and (2) on the DomainBed benchmark.

Colored MNIST in the IRM setup

We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST.

Main results (Table 2 in Section 6.A)

To reproduce the results from Table 2, call python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm where algorithm is either:

Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal IRM implementation.

    Method | Train acc. | Test acc.  | Gray test acc.
   --------|------------|------------|----------------
    ERM    | 86.4 ± 0.2 | 14.0 ± 0.7 |   71.0 ± 0.7
    IRM    | 71.0 ± 0.5 | 65.6 ± 1.8 |   66.1 ± 0.2
    V-REx  | 71.7 ± 1.5 | 67.2 ± 1.5 |   68.6 ± 2.2
    Fishr  | 71.0 ± 0.9 | 69.5 ± 1.0 |   70.2 ± 1.1

Without label flipping (Table 5 in Appendix C.2.3)

The script coloredmnist.train_coloredmnist also accepts as input the argument --label_flipping_prob which defines the label flipping probability. By default, it's 0.25, so to reproduce the results from Table 5 you should set --label_flipping_prob 0.

Fishr variants (Table 6 in Appendix C.2.4)

This table considers two additional Fishr variants, reproduced with algorithm set to:

  • fishr_offdiagonal for Fishr but without centering the gradient variances
  • fishr_notcentered for Fishr but on the full covariance rather than only the diagonal

DomainBed

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in In Search of Lost Domain Generalization. Instructions below are copied and adapted from the official github.

Algorithms and hyperparameter grids

We added Fishr as a new algorithm here, and defined Fishr's hyperparameter grids here, as defined in Table 7 in Appendix D.

Datasets

We ran Fishr on following datasets:

Launch training

Download the datasets:

python3 -m domainbed.scripts.download\
       --data_dir=/my/data/dir

Train a model for debugging:

python3 -m domainbed.scripts.train\
       --data_dir=/my/data/dir/\
       --algorithm Fishr\
       --dataset ColoredMNIST\
       --test_env 2

Launch a sweep for hyperparameter search:

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher
       --datasets ColoredMNIST\
       --algorithms Fishr

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py.

Performances inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G)

To view the results of your sweep:

python -m domainbed.scripts.collect_results\
       --input_dir=/my/sweep/output/path

We inspect performances using following model selection criteria, that differ in what data is used to choose the best hyper-parameters for a given model:

  • OracleSelectionMethod (Oracle): A random subset from the data of the test domain.
  • IIDAccuracySelectionMethod (Training): A random subset from the data of the training domains.

Critically, Fishr performs consistently better than Empirical Risk Minimization.

Model selection Algorithm Colored MNIST Rotated MNIST VLCS PACS OfficeHome TerraIncognita DomainNet Avg
Oracle ERM 57.8 ± 0.2 97.8 ± 0.1 77.6 ± 0.3 86.7 ± 0.3 66.4 ± 0.5 53.0 ± 0.3 41.3 ± 0.1 68.7
Oracle Fishr 68.8 ± 1.4 97.8 ± 0.1 78.2 ± 0.2 86.9 ± 0.2 68.2 ± 0.2 53.6 ± 0.4 41.8 ± 0.2 70.8
Training ERM 51.5 ± 0.1 98.0 ± 0.0 77.5 ± 0.4 85.5 ± 0.2 66.5 ± 0.3 46.1 ± 1.8 40.9 ± 0.1 66.6
Training Fishr 52.0 ± 0.2 97.8 ± 0.0 77.8 ± 0.1 85.5 ± 0.4 67.8 ± 0.1 47.4 ± 1.6 41.7 ± 0.0 67.1

Conclusion

We addressed the task of out-of-distribution generalization for computer vision classification tasks. We derive a new and simple regularization - Fishr - that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performances on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization would consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help to use Fishr, please open an issue or contact [email protected].

Citation

If you find this code useful for your research, please consider citing our work (under review):

@article{rame2021ishr,
    title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
    author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2109.02934}
}
An Empirical Investigation of Model-to-Model Distribution Shifts in Trained Convolutional Filters

CNN-Filter-DB An Empirical Investigation of Model-to-Model Distribution Shifts in Trained Convolutional Filters Paul Gavrikov, Janis Keuper Paper: htt

Paul Gavrikov 18 Dec 30, 2022
Multi-modal Text Recognition Networks: Interactive Enhancements between Visual and Semantic Features

Multi-modal Text Recognition Networks: Interactive Enhancements between Visual and Semantic Features | paper | Official PyTorch implementation for Mul

48 Dec 28, 2022
Code implementing "Improving Deep Learning Interpretability by Saliency Guided Training"

Saliency Guided Training Code implementing "Improving Deep Learning Interpretability by Saliency Guided Training" by Aya Abdelsalam Ismail, Hector Cor

8 Sep 22, 2022
A stock generator that assess a list of stocks and returns the best stocks for investing and money allocations based on users choices of volatility, duration and number of stocks

Stock-Generator Please visit "Stock Generator.ipynb" for a clearer view and "Stock Generator.py" for scripts. The stock generator is designed to allow

jmengnyay 1 Aug 02, 2022
Seeing All the Angles: Learning Multiview Manipulation Policies for Contact-Rich Tasks from Demonstrations

Seeing All the Angles: Learning Multiview Manipulation Policies for Contact-Rich Tasks from Demonstrations Trevor Ablett, Daniel (Yifan) Zhai, Jonatha

STARS Laboratory 3 Feb 01, 2022
Code for the Paper: Conditional Variational Capsule Network for Open Set Recognition

Conditional Variational Capsule Network for Open Set Recognition This repository hosts the official code related to "Conditional Variational Capsule N

Guglielmo Camporese 35 Nov 21, 2022
git《USD-Seg:Learning Universal Shape Dictionary for Realtime Instance Segmentation》(2020) GitHub: [fig2]

USD-Seg This project is an implement of paper USD-Seg:Learning Universal Shape Dictionary for Realtime Instance Segmentation, based on FCOS detector f

Ruolin Ye 80 Nov 28, 2022
Place holder for HOPE: a human-centric and task-oriented MT evaluation framework using professional post-editing

HOPE: A Task-Oriented and Human-Centric Evaluation Framework Using Professional Post-Editing Towards More Effective MT Evaluation Place holder for dat

Lifeng Han 1 Apr 25, 2022
Co-GAIL: Learning Diverse Strategies for Human-Robot Collaboration

CoGAIL Table of Content Overview Installation Dataset Training Evaluation Trained Checkpoints Acknowledgement Citations License Overview This reposito

Jeremy Wang 29 Dec 24, 2022
This project deals with the detection of skin lesions within the ISICs dataset using YOLOv3 Object Detection with Darknet.

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License. Skin Lesion detection using YOLO This project deal

Lalith Veerabhadrappa Badiger 1 Nov 22, 2021
Deep learning toolbox based on PyTorch for hyperspectral data classification.

Deep learning toolbox based on PyTorch for hyperspectral data classification.

Nicolas 304 Dec 28, 2022
PyTorch Implementation of [1611.06440] Pruning Convolutional Neural Networks for Resource Efficient Inference

PyTorch implementation of [1611.06440 Pruning Convolutional Neural Networks for Resource Efficient Inference] This demonstrates pruning a VGG16 based

Jacob Gildenblat 836 Dec 26, 2022
[ICML 2020] DrRepair: Learning to Repair Programs from Error Messages

DrRepair: Learning to Repair Programs from Error Messages This repo provides the source code & data of our paper: Graph-based, Self-Supervised Program

Michihiro Yasunaga 155 Jan 08, 2023
Implementation of Restricted Boltzmann Machine (RBM) and its variants in Tensorflow

xRBM Library Implementation of Restricted Boltzmann Machine (RBM) and its variants in Tensorflow Installation Using pip: pip install xrbm Examples Tut

Omid Alemi 55 Dec 29, 2022
Library extending Jupyter notebooks to integrate with Apache TinkerPop and RDF SPARQL.

Graph Notebook: easily query and visualize graphs The graph notebook provides an easy way to interact with graph databases using Jupyter notebooks. Us

Amazon Web Services 501 Dec 28, 2022
A python implementation of Deep-Image-Analogy based on pytorch.

Deep-Image-Analogy This project is a python implementation of Deep Image Analogy.https://arxiv.org/abs/1705.01088. Some results Requirements python 3

Peng Lu 171 Dec 14, 2022
ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation This repository contains the source code of our paper, ESPNet (acc

Sachin Mehta 515 Dec 13, 2022
PyTorch code for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation

BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation

Salesforce 1.3k Dec 31, 2022
Megaverse is a new 3D simulation platform for reinforcement learning and embodied AI research

Megaverse Megaverse is a new 3D simulation platform for reinforcement learning and embodied AI research. The efficient design of the engine enables ph

Aleksei Petrenko 191 Dec 23, 2022
Implementation of Online Label Smoothing in PyTorch

Online Label Smoothing Pytorch implementation of Online Label Smoothing (OLS) presented in Delving Deep into Label Smoothing. Introduction As the abst

83 Dec 14, 2022