Official implementation of paper Gradient Matching for Domain Generalization

Gradient Matching for Domain Generalisation

This is the official PyTorch implementation of Gradient Matching for Domain Generalisation. In our paper, we propose an inter-domain gradient matching (IDGM) objective that targets domain generalization by maximizing the inner product between gradients from different domains. To avoid computing the expensive second-order derivative of the IDGM objective, we derive a simpler first-order algorithm named Fish that approximates its optimization.
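The two ideas can be illustrated with a minimal, framework-free sketch. It is not the repository's implementation. It assumes a toy 1-D linear model y = w*x + b with squared loss and hand-written gradients, and the hyperparameter names `inner_lr` and `meta_lr` are hypothetical. `gradient_inner_product` computes the average inner product between gradients of different domains, the quantity IDGM maximizes. `fish_step` follows the Fish recipe: copy the parameters, take one SGD step per domain in sequence on the copy, then move the original parameters part of the way toward the copy.

```python
import random

# Toy model (assumption): y = w * x + b, parameters stored as [w, b].

def mse_grad(params, batch):
    """Gradient of the mean squared error of the toy linear model on one batch."""
    w, b = params
    gw = gb = 0.0
    for x, y in batch:
        err = w * x + b - y
        gw += 2.0 * err * x
        gb += 2.0 * err
    n = len(batch)
    return [gw / n, gb / n]

def dot(u, v):
    return sum(a * c for a, c in zip(u, v))

def gradient_inner_product(params, domain_batches):
    """Average inner product between gradients of every pair of distinct domains."""
    grads = [mse_grad(params, batch) for batch in domain_batches]
    s = len(grads)
    total = sum(dot(grads[i], grads[j]) for i in range(s) for j in range(s) if i != j)
    return total / (s * (s - 1))

def fish_step(params, domain_batches, inner_lr=0.05, meta_lr=0.5, rng=random):
    """One Fish update: sequential per-domain SGD on a copy, then interpolate."""
    fast = list(params)                      # clone the parameters
    order = list(range(len(domain_batches)))
    rng.shuffle(order)                       # visit domains in random order
    for d in order:
        g = mse_grad(fast, domain_batches[d])
        fast = [p - inner_lr * gi for p, gi in zip(fast, g)]
    # Move the original parameters a fraction meta_lr toward the updated copy.
    return [p + meta_lr * (f - p) for p, f in zip(params, fast)]

def erm_step(params, domain_batches, lr=0.05):
    """ERM baseline for comparison: one SGD step on all domains pooled together."""
    pooled = [ex for batch in domain_batches for ex in batch]
    g = mse_grad(params, pooled)
    return [p - lr * gi for p, gi in zip(params, g)]

if __name__ == "__main__":
    rng = random.Random(0)
    # Three toy "domains" that share y = 2x + 1 but cover different x ranges.
    domains = [[(x, 2 * x + 1) for x in (lo, lo + 0.5, lo + 1.0)] for lo in (-1.0, 0.0, 1.0)]
    params = [0.0, 0.0]
    for _ in range(500):
        params = fish_step(params, domains, rng=rng)
    print([round(p, 2) for p in params])  # approximately [2.0, 1.0]
```

Per the paper's first-order analysis, the sequential inner loop is what matters: each domain's gradient is evaluated after the previous domains' steps, which is how the interpolated update picks up the cross-domain gradient inner-product term that a single pooled ERM step lacks.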

This repository contains code to reproduce the main results of our paper.

Dependencies

(Recommended) You can set up a conda environment with all required dependencies using environment.yml:

conda env create -f environment.yml
conda activate fish

Otherwise, you can install the following packages manually:

python=3.7.10
numpy=1.20.2
pytorch=1.8.1
torchaudio=0.8.1
torchvision=0.9.1
torch-cluster=1.5.9
torch-geometric=1.7.0
torch-scatter=2.0.6
torch-sparse=0.6.9
wilds=1.1.0
scikit-learn=0.24.2
scipy=1.6.3
seaborn=0.11.1
tqdm=4.61.0
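To check a manually built environment against these pins, a small script along the following lines may help. It is a sketch under stated assumptions: it uses `importlib.metadata` (standard library from Python 3.8; on the pinned Python 3.7.10 the `importlib_metadata` backport must be installed), it assumes the conda package `pytorch` is registered under the distribution name `torch`, and it leaves out the torch-cluster, torch-geometric, torch-scatter and torch-sparse packages, whose registered names may be spelled differently, so check those by hand.

```python
import sys

# Pins copied from the list above; keys are distribution names (assumption:
# the "pytorch" conda package is registered as "torch").
PINS = {
    "numpy": "1.20.2",
    "torch": "1.8.1",
    "torchaudio": "0.8.1",
    "torchvision": "0.9.1",
    "wilds": "1.1.0",
    "scikit-learn": "0.24.2",
    "scipy": "1.6.3",
    "seaborn": "0.11.1",
    "tqdm": "4.61.0",
}

def installed_version(name):
    """Return the installed version string of a distribution, or None if missing."""
    try:
        from importlib.metadata import version, PackageNotFoundError
    except ImportError:  # Python 3.7: fall back to the importlib_metadata backport
        from importlib_metadata import version, PackageNotFoundError
    try:
        return version(name)
    except PackageNotFoundError:
        return None

def check_pins(pins, lookup=installed_version):
    """Return {name: (pinned, found)} for every package whose version differs."""
    mismatches = {}
    for name, pinned in pins.items():
        found = lookup(name)
        if found != pinned:
            mismatches[name] = (pinned, found)
    return mismatches

if __name__ == "__main__":
    print("python", sys.version.split()[0], "(pinned: 3.7.10)")
    for name, (pinned, found) in sorted(check_pins(PINS).items()):
        print(f"{name}: pinned {pinned}, found {found or 'not installed'}")
```

The comparison is an exact string match, so builds with a local suffix (for example a CUDA tag appended to the torch version) are reported as mismatches even when the base version agrees.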

Running Experiments

You can train with either our proposed method, Fish, or the Empirical Risk Minimisation (ERM) baseline. Select one with the --algorithm flag (either fish or erm).

CdSprites-N

We propose CdSprites-N, a simple shape-color dataset based on the dSprites dataset, which contains white 2D sprites of different shapes, scales, rotations and positions. CdSprites-N contains N domains, where N can be specified. The task is to classify the shape of each sprite. Within each training domain, every shape is deterministically paired with a color, and this pairing is specific to that domain. Shape is therefore the invariant feature and color the spurious one. On the test set, however, the correlation between color and shape is removed. See the image below for an illustration.

[Figure: illustration of the CdSprites-N dataset]
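The label structure described above can be sketched as follows. This only models labels, not images, and the generation scheme is hypothetical: colors are integer ids from a palette of `num_colors` entries (at least 3), and each training domain receives a randomly drawn, injective shape-to-color assignment. The actual CdSprites-N generation procedure may differ. The three shape names are those of dSprites.

```python
import random

SHAPES = ("square", "ellipse", "heart")  # the three dSprites shapes

def make_domain_maps(num_domains, num_colors, seed=0):
    """Hypothetical: give each training domain its own shape -> color mapping."""
    rng = random.Random(seed)
    maps = []
    for _ in range(num_domains):
        colors = rng.sample(range(num_colors), len(SHAPES))  # distinct colors
        maps.append(dict(zip(SHAPES, colors)))
    return maps

def sample_train(domain_maps, n_per_domain, seed=0):
    """Training examples: color is fully determined by (domain, shape)."""
    rng = random.Random(seed)
    data = []
    for d, mapping in enumerate(domain_maps):
        for _ in range(n_per_domain):
            shape = rng.choice(SHAPES)
            data.append({"domain": d, "shape": shape, "color": mapping[shape]})
    return data

def sample_test(num_colors, n, seed=0):
    """Test examples: color drawn independently of shape, removing the shortcut."""
    rng = random.Random(seed)
    return [{"shape": rng.choice(SHAPES), "color": rng.randrange(num_colors)}
            for _ in range(n)]

if __name__ == "__main__":
    maps = make_domain_maps(num_domains=5, num_colors=10)
    print(maps[0], maps[1])
```

Inside any single training domain, color alone predicts the shape perfectly, but because the mapping changes from domain to domain, only shape stays predictive across all domains and on the test set.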

The CdSprites-N dataset can be downloaded here. After downloading, please extract the zip file to your preferred data directory (e.g. <your_data_dir>/cdsprites). The following command runs an experiment using Fish with N=15 domains:

python main.py --dataset cdsprites --algorithm fish --data-dir <your_data_dir> --num-domains 15

The supported numbers of domains are N = 5, 10, 15, 20, 25, 30, 35, 40, 45, 50.
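To compare Fish and ERM over every supported N, a small driver script such as the one below could be used. It is a sketch that only uses the flags shown above. It assumes it is run from the repository root, `/path/to/data` is a placeholder, and by default it only prints the commands; pass `dry_run=False` to run them one after another.

```python
import shlex
import subprocess

SUPPORTED_N = tuple(range(5, 55, 5))  # 5, 10, ..., 50 as listed above

def cdsprites_command(data_dir, num_domains, algorithm="fish"):
    """Build the main.py command line for one CdSprites-N experiment."""
    if num_domains not in SUPPORTED_N:
        raise ValueError(f"num_domains must be one of {SUPPORTED_N}, got {num_domains}")
    if algorithm not in ("fish", "erm"):
        raise ValueError(f"algorithm must be 'fish' or 'erm', got {algorithm!r}")
    return ["python", "main.py", "--dataset", "cdsprites", "--algorithm", algorithm,
            "--data-dir", data_dir, "--num-domains", str(num_domains)]

def sweep(data_dir, algorithms=("fish", "erm"), dry_run=True):
    """Print (and optionally run) one experiment per (algorithm, N) pair."""
    for algorithm in algorithms:
        for n in SUPPORTED_N:
            cmd = cdsprites_command(data_dir, n, algorithm)
            print(" ".join(shlex.quote(part) for part in cmd))
            if not dry_run:
                subprocess.run(cmd, check=True)  # stop on the first failing run

if __name__ == "__main__":
    sweep("/path/to/data")
```

Commands are joined with `shlex.quote` rather than `shlex.join`, since the latter only exists from Python 3.8 and the pinned environment uses Python 3.7.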

WILDS

We include the following 6 datasets from the WILDS benchmark: amazon, camelyon, civil, fmow, iwildcam, poverty. The datasets can be downloaded automatically to a specified data folder. For instance, to train with Fish on the Amazon dataset, run:

python main.py --dataset amazon --algorithm fish --data-dir <your_data_dir>

This should automatically download the Amazon dataset to <your_data_dir>/wilds. Experiments on the other datasets can be run with the following commands:

python main.py --dataset camelyon --algorithm fish --data-dir <your_data_dir>
python main.py --dataset civil --algorithm fish --data-dir <your_data_dir>
python main.py --dataset fmow --algorithm fish --data-dir <your_data_dir>
python main.py --dataset iwildcam --algorithm fish --data-dir <your_data_dir>
python main.py --dataset poverty --algorithm fish --data-dir <your_data_dir>
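The automatic download goes through the `wilds` package. To fetch a dataset before training, a sketch using the package's `get_dataset(dataset=..., download=True, root_dir=...)` entry point is given below. Two parts are assumptions: the mapping from this repository's short `--dataset` names to official WILDS identifiers (for example camelyon to camelyon17 and civil to civilcomments), and the expectation that main.py reuses files placed under <your_data_dir>/wilds instead of downloading them again.

```python
import os

# Assumed mapping from this repository's --dataset names to WILDS dataset names.
WILDS_NAMES = {
    "amazon": "amazon",
    "camelyon": "camelyon17",
    "civil": "civilcomments",
    "fmow": "fmow",
    "iwildcam": "iwildcam",
    "poverty": "poverty",
}

def wilds_root(data_dir):
    """Folder under which the WILDS datasets are stored, per this README."""
    return os.path.join(data_dir, "wilds")

def prefetch(short_name, data_dir):
    """Download one WILDS dataset ahead of training (requires the wilds package)."""
    from wilds import get_dataset  # imported lazily so the mapping works without wilds
    return get_dataset(dataset=WILDS_NAMES[short_name], download=True,
                       root_dir=wilds_root(data_dir))

if __name__ == "__main__":
    prefetch("amazon", "/path/to/data")  # placeholder data directory
```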

Alternatively, you can download the datasets to <your_data_dir>/wilds manually by following the instructions here. Current results on WILDS are shown here:

[Figure: current results on WILDS]

DomainBed

For experiments on CMNIST, RMNIST, VLCS, PACS, OfficeHome, TerraInc and DomainNet, we implemented Fish within the DomainBed benchmark (see here), where it can be compared against up to 20 state-of-the-art baselines. Current results on DomainBed are shown here:

[Figure: current results on DomainBed]

Citation

If you use this code in your research, we would appreciate it if you cited the paper most relevant to your work:

@article{shi2021gradient,
	title="Gradient Matching for Domain Generalization.",
	author="Yuge {Shi} and Jeffrey {Seely} and Philip H. S. {Torr} and N. {Siddharth} and Awni {Hannun} and Nicolas {Usunier} and Gabriel {Synnaeve}",
	journal="arXiv preprint arXiv:2104.09937",
	year="2021"}

Contributions

We welcome contributions via pull requests. Please email [email protected] or [email protected] with any questions or requests.
