Fishr: Invariant Gradient Variances for Out-of-distribution Generalization
Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization (paper: arXiv:2109.02934).
Alexandre Ramé, Corentin Dancette, Matthieu Cord
Abstract
Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been growing interest in learning simultaneously from multiple training domains while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.
In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.
Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than Empirical Risk Minimization.
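To make the regularization described in the abstract concrete, here is a minimal NumPy sketch, not the repository's implementation. It assumes per-sample gradients are already available for each training domain as arrays of shape (n_samples, n_params). It takes the centered per-parameter variance within each domain and penalizes the mean squared distance between each domain's variance vector and their average across domains. The function name fishr_penalty is hypothetical.

```python
import numpy as np


def fishr_penalty(grads_per_domain):
    """Illustrative Fishr-style penalty (hypothetical helper).

    grads_per_domain: list of arrays, one per training domain, each of shape
    (n_samples, n_params) holding per-sample gradients.
    """
    # Centered, per-parameter (diagonal) gradient variance within each domain.
    variances = np.stack([g.var(axis=0) for g in grads_per_domain])
    # Average of the domain-level variances.
    mean_variance = variances.mean(axis=0)
    # Mean over domains of the squared L2 distance to the average variance.
    return float(((variances - mean_variance) ** 2).sum(axis=1).mean())


# Two toy domains with different gradient variances.
domain_a = np.array([[1.0, 0.0], [3.0, 0.0]])  # variances [1, 0]
domain_b = np.array([[0.0, 0.0], [0.0, 2.0]])  # variances [0, 1]
print(fishr_penalty([domain_a, domain_b]))  # 0.5
print(fishr_penalty([domain_a, domain_a]))  # 0.0
```

In training, a term like this would be scaled by a coefficient and added to the average empirical risk over the training domains.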
Installation
Requirements overview
Our implementation relies on the BackPACK package in PyTorch to easily compute gradient variances.
- python == 3.7.10
- torch == 1.8.1
- torchvision == 0.9.1
- backpack-for-pytorch == 1.3.0
- numpy == 1.20.2
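BackPACK extends PyTorch's backward pass so that per-sample gradients (its BatchGrad extension) and statistics such as their variance (its Variance extension) are available from a single backward call. As a dependency-free illustration of that quantity, the sketch below assumes a linear softmax classifier trained with cross-entropy, for which per-sample gradients have a closed form: the gradient with respect to the weight matrix is the outer product of (softmax output minus one-hot label) and the input features. The helper names are hypothetical; the column-wise variance of the result is the kind of domain-level statistic Fishr matches.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def per_sample_grads_linear(features, labels, weight):
    """Per-sample cross-entropy gradients w.r.t. `weight` of shape (n_classes, n_features).

    Returns an array of shape (n_samples, n_classes * n_features).
    """
    probs = softmax(features @ weight.T)                 # (n, c)
    onehot = np.eye(weight.shape[0])[labels]             # (n, c)
    residual = probs - onehot                            # dLoss/dlogits, per sample
    grads = residual[:, :, None] * features[:, None, :]  # (n, c, d) outer products
    return grads.reshape(len(features), -1)


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4))         # 32 samples, 4 features
y = rng.integers(0, 3, size=32)      # 3 classes
w = rng.normal(size=(3, 4))
g = per_sample_grads_linear(x, y, w)
print(g.shape)                       # (32, 12)
print(g.var(axis=0).shape)           # (12,): one variance per parameter
```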
Procedure
- Clone the repo:
$ git clone https://github.com/alexrame/fishr.git
- Create a conda environment and install the dependencies using pip:
$ conda create --name fishr python=3.7.10
$ conda activate fishr
$ cd fishr
$ pip install -r requirements.txt
With this, you can edit the Fishr code on the fly.
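After installation, the environment can be compared against the pins in the requirements overview with a short script like the one below (a sketch; the script and function names are hypothetical). It reads installed versions with importlib.metadata, falling back to the importlib-metadata backport on Python 3.7, and ignores local build suffixes such as +cu111 on torch.

```python
import sys

try:  # Python >= 3.8
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7: requires `pip install importlib-metadata`
    from importlib_metadata import PackageNotFoundError, version

# Pins from the requirements overview.
PINS = {
    "torch": "1.8.1",
    "torchvision": "0.9.1",
    "backpack-for-pytorch": "1.3.0",
    "numpy": "1.20.2",
}


def find_mismatches(pins, get_version=version):
    """Return {package: (expected, installed or None)} for every unmet pin."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None
        # Ignore local build suffixes such as "1.8.1+cu111".
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    print("python", ".".join(map(str, sys.version_info[:3])), "(expected 3.7.10)")
    for name, (expected, installed) in find_mismatches(PINS).items():
        print(f"{name}: expected {expected}, found {installed}")
```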
Overview
This repository enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by IRM and (2) on the DomainBed benchmark.
Colored MNIST in the IRM setup
We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST dataset.
Main results (Table 2 in Section 6.A)
To reproduce the results from Table 2, run python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm, where $algorithm is one of:
- erm for Empirical Risk Minimization
- irm for Invariant Risk Minimization
- rex for Out-of-Distribution Generalization via Risk Extrapolation
- fishr for our proposed Fishr
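To run all four algorithms in sequence, a small Python driver can build the same command lines. This is a hypothetical helper, not part of the repository. It assumes it is launched from the repository root (the script path is relative), uses sys.executable so the commands run with the active environment's interpreter, and only prints the commands unless dry_run=False. It optionally forwards the --label_flipping_prob argument accepted by the script.

```python
import subprocess
import sys

ALGORITHMS = ["erm", "irm", "rex", "fishr"]


def coloredmnist_command(algorithm, label_flipping_prob=None):
    """Build the argv list for one Colored MNIST run (hypothetical helper)."""
    cmd = [sys.executable, "coloredmnist/train_coloredmnist.py", "--algorithm", algorithm]
    if label_flipping_prob is not None:
        cmd += ["--label_flipping_prob", str(label_flipping_prob)]
    return cmd


def run_all(dry_run=True, label_flipping_prob=None):
    """Print the command for every algorithm and, unless dry_run, execute it."""
    for algorithm in ALGORITHMS:
        cmd = coloredmnist_command(algorithm, label_flipping_prob)
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```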
Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal IRM implementation.
| Method | Train acc. | Test acc. | Gray test acc. |
|---|---|---|---|
| ERM | 86.4 ± 0.2 | 14.0 ± 0.7 | 71.0 ± 0.7 |
| IRM | 71.0 ± 0.5 | 65.6 ± 1.8 | 66.1 ± 0.2 |
| V-REx | 71.7 ± 1.5 | 67.2 ± 1.5 | 68.6 ± 2.2 |
| Fishr | 71.0 ± 0.9 | 69.5 ± 1.0 | 70.2 ± 1.1 |
Without label flipping (Table 5 in Appendix C.2.3)
The script coloredmnist/train_coloredmnist.py also accepts the argument --label_flipping_prob, which sets the label flipping probability. It defaults to 0.25, so to reproduce the results from Table 5, set --label_flipping_prob 0.
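For intuition about what --label_flipping_prob controls: in the IRM Colored MNIST setup, each binary label (whether the digit is below 5) is flipped independently with this probability, so setting it to 0 removes the label noise. The NumPy sketch below illustrates the mechanism; the function name is hypothetical and this is not the script's code.

```python
import numpy as np


def flip_labels(labels, flip_prob, rng):
    """Flip each binary (0/1) label independently with probability flip_prob."""
    flips = rng.random(labels.shape) < flip_prob
    return np.logical_xor(labels.astype(bool), flips).astype(labels.dtype)


rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=100_000)
noisy = flip_labels(labels, 0.25, rng)
print((noisy != labels).mean())  # close to 0.25
```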
Fishr variants (Table 6 in Appendix C.2.4)
This table considers two additional Fishr variants, reproduced by setting $algorithm to:
- fishr_offdiagonal for Fishr but on the full covariance rather than only the diagonal
- fishr_notcentered for Fishr but without centering the gradient variances
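The sketch below contrasts the three gradient statistics involved, following our reading of the variant descriptions above: the default centered diagonal variance, the uncentered diagonal second moment, and the full covariance matrix (assumed centered). It is a NumPy illustration with a hypothetical function name, not the repository's implementation. For p parameters, the diagonal needs only p values, whereas the full covariance needs p × p.

```python
import numpy as np


def gradient_statistic(grads, centered=True, diagonal=True):
    """Domain-level statistic of per-sample gradients `grads` with shape (n, p).

    centered=True,  diagonal=True  -> per-parameter variance (default Fishr)
    centered=False, diagonal=True  -> per-parameter second moment (like fishr_notcentered)
    centered=True,  diagonal=False -> full p x p covariance (like fishr_offdiagonal)
    """
    if centered:
        grads = grads - grads.mean(axis=0, keepdims=True)
    if diagonal:
        return (grads ** 2).mean(axis=0)     # p values
    return grads.T @ grads / grads.shape[0]  # p * p values


g = np.array([[1.0, 0.0], [3.0, 2.0]])
print(gradient_statistic(g))                  # [1. 1.]
print(gradient_statistic(g, centered=False))  # [5. 2.]
print(gradient_statistic(g, diagonal=False))  # 2 x 2 matrix of ones
```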
DomainBed
DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in In Search of Lost Domain Generalization. The instructions below are copied and adapted from the official DomainBed GitHub repository.
Algorithms and hyperparameter grids
We added Fishr as a new algorithm to DomainBed and defined its hyperparameter grids as specified in Table 7 in Appendix D.
Datasets
We ran Fishr on the following datasets:
- Rotated MNIST (Ghifary et al., 2015)
- Colored MNIST (Arjovsky et al., 2019)
- VLCS (Fang et al., 2013)
- PACS (Li et al., 2017)
- OfficeHome (Venkateswara et al., 2017)
- A subset of TerraIncognita (Beery et al., 2018)
- DomainNet (Peng et al., 2019)
Launch training
Download the datasets:
python3 -m domainbed.scripts.download\
       --data_dir=/my/data/dir
Train a model for debugging:
python3 -m domainbed.scripts.train\
       --data_dir=/my/data/dir/\
       --algorithm Fishr\
       --dataset ColoredMNIST\
       --test_env 2
Launch a sweep for hyperparameter search:
python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher\
       --datasets ColoredMNIST\
       --algorithms Fishr
Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py.
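A command launcher in DomainBed's command_launchers.py is a function that receives the list of shell commands generated by the sweep. As an example for a Slurm cluster (an assumption about your setup), the hypothetical launcher below submits each command as its own job with sbatch --wrap. Cluster-specific options such as GPU requests, partitions or time limits would need to be added.

```python
import subprocess


def slurm_wrap_launcher(commands, dry_run=False):
    """Hypothetical DomainBed launcher: one Slurm job per command via `sbatch --wrap`.

    `commands` is the list of shell command strings produced by the sweep.
    With dry_run=True, nothing is submitted and the sbatch argv lists are returned.
    """
    submissions = []
    for cmd in commands:
        # Add cluster-specific sbatch options here (e.g. GPUs, partition).
        sbatch_cmd = ["sbatch", "--wrap", cmd]
        submissions.append(sbatch_cmd)
        if not dry_run:
            subprocess.run(sbatch_cmd, check=True)
    return submissions
```

It could then be registered under a name, for instance by adding an entry such as 'my_slurm': slurm_wrap_launcher to the REGISTRY dictionary in command_launchers.py, and selected with --command_launcher my_slurm.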
Performance inspection (Tables 3 and 4 in Section 6.B.2, and tables in Appendix G)
To view the results of your sweep:
python -m domainbed.scripts.collect_results\
       --input_dir=/my/sweep/output/path
We inspect performance using the following model selection criteria, which differ in what data is used to choose the best hyperparameters for a given model:
- OracleSelectionMethod (Oracle): a random subset from the data of the test domain.
- IIDAccuracySelectionMethod (Training): a random subset from the data of the training domains.
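A simplified sketch of the difference between the two criteria, under stated assumptions: each run (e.g. one hyperparameter setting) is a dict with hypothetical keys for accuracy on held-out training-domain data, accuracy on a held-out subset of the test domain, and the reported test accuracy. The numbers are toy values. DomainBed's actual selection methods operate on its logged records and handle additional details this sketch omits.

```python
def select_best(runs, criterion):
    """Pick the run that maximizes the chosen selection criterion.

    Hypothetical record keys:
      "train_domains_val_acc": accuracy on held-out data from the training domains
      "test_domain_val_acc":   accuracy on a held-out random subset of the test domain
      "test_acc":              test-domain accuracy that gets reported
    """
    key = {"training": "train_domains_val_acc", "oracle": "test_domain_val_acc"}[criterion]
    return max(runs, key=lambda run: run[key])


runs = [
    {"hparams": "A", "train_domains_val_acc": 0.90, "test_domain_val_acc": 0.50, "test_acc": 0.52},
    {"hparams": "B", "train_domains_val_acc": 0.80, "test_domain_val_acc": 0.70, "test_acc": 0.69},
]
print(select_best(runs, "training")["hparams"])  # A
print(select_best(runs, "oracle")["hparams"])    # B
```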
Fishr outperforms Empirical Risk Minimization on average under both model selection criteria, and on most individual datasets.
| Model selection | Algorithm | Colored MNIST | Rotated MNIST | VLCS | PACS | OfficeHome | TerraIncognita | DomainNet | Avg | 
|---|---|---|---|---|---|---|---|---|---|
| Oracle | ERM | 57.8 ± 0.2 | 97.8 ± 0.1 | 77.6 ± 0.3 | 86.7 ± 0.3 | 66.4 ± 0.5 | 53.0 ± 0.3 | 41.3 ± 0.1 | 68.7 | 
| Oracle | Fishr | 68.8 ± 1.4 | 97.8 ± 0.1 | 78.2 ± 0.2 | 86.9 ± 0.2 | 68.2 ± 0.2 | 53.6 ± 0.4 | 41.8 ± 0.2 | 70.8 | 
| Training | ERM | 51.5 ± 0.1 | 98.0 ± 0.0 | 77.5 ± 0.4 | 85.5 ± 0.2 | 66.5 ± 0.3 | 46.1 ± 1.8 | 40.9 ± 0.1 | 66.6 | 
| Training | Fishr | 52.0 ± 0.2 | 97.8 ± 0.0 | 77.8 ± 0.1 | 85.5 ± 0.4 | 67.8 ± 0.1 | 47.4 ± 1.6 | 41.7 ± 0.0 | 67.1 | 
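The Avg column is consistent with an unweighted mean over the seven datasets, which can be checked by recomputing it from the table values:

```python
# Per-dataset accuracies copied from the table above (order: Colored MNIST,
# Rotated MNIST, VLCS, PACS, OfficeHome, TerraIncognita, DomainNet).
RESULTS = {
    ("Oracle", "ERM"): [57.8, 97.8, 77.6, 86.7, 66.4, 53.0, 41.3],
    ("Oracle", "Fishr"): [68.8, 97.8, 78.2, 86.9, 68.2, 53.6, 41.8],
    ("Training", "ERM"): [51.5, 98.0, 77.5, 85.5, 66.5, 46.1, 40.9],
    ("Training", "Fishr"): [52.0, 97.8, 77.8, 85.5, 67.8, 47.4, 41.7],
}

for (selection, algorithm), accs in RESULTS.items():
    avg = sum(accs) / len(accs)
    print(f"{selection:8s} {algorithm:5s} avg = {avg:.1f}")
# Oracle   ERM   avg = 68.7
# Oracle   Fishr avg = 70.8
# Training ERM   avg = 66.6
# Training Fishr avg = 67.1
```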
Conclusion
We addressed out-of-distribution generalization for computer vision classification tasks. We derived a new and simple regularization, Fishr, that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performance on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization could consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help using Fishr, please open an issue or contact [email protected].
Citation
If you find this code useful for your research, please consider citing our work (under review):
@article{rame2021fishr,
    title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
    author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2109.02934}
}
