Official implementation of the paper Gradient Matching for Domain Generalization

Gradient Matching for Domain Generalisation

This is the official PyTorch implementation of Gradient Matching for Domain Generalisation. In our paper, we propose an inter-domain gradient matching (IDGM) objective that targets domain generalization by maximizing the inner product between gradients from different domains. To avoid computing the expensive second-order derivative of the IDGM objective, we derive a simpler first-order algorithm named Fish that approximates its optimization.
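
The IDGM objective can be written out as below. This is a notation sketch with assumed symbols: S training domains, per-domain losses L_i with gradients G_i, a matching weight γ, and the ERM term taken as the average per-domain loss. See the paper for the exact formulation. The last line shows where the second-order cost comes from: differentiating an inner product of gradients requires Hessian-vector products.

```latex
G_i(\theta) = \nabla_\theta \mathcal{L}_i(\theta), \qquad i = 1, \dots, S

\mathcal{L}_{\mathrm{idgm}}(\theta)
  = \frac{1}{S} \sum_{i=1}^{S} \mathcal{L}_i(\theta)
  \;-\; \gamma \, \frac{2}{S(S-1)} \sum_{1 \le i < j \le S} G_i(\theta)^{\top} G_j(\theta)

\nabla_\theta \big( G_i^{\top} G_j \big) = H_i \, G_j + H_j \, G_i,
  \qquad H_i = \nabla_\theta^{2} \mathcal{L}_i(\theta)
```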

This repository contains code to reproduce the main results of our paper.

Dependencies

(Recommended) You can set up a conda environment with all required dependencies using environment.yml:

conda env create -f environment.yml
conda activate fish

Alternatively, you can install the following packages manually:

python=3.7.10
numpy=1.20.2
pytorch=1.8.1
torchaudio=0.8.1
torchvision=0.9.1
torch-cluster=1.5.9
torch-geometric=1.7.0
torch-scatter=2.0.6
torch-sparse=0.6.9
wilds=1.1.0
scikit-learn=0.24.2
scipy=1.6.3
seaborn=0.11.1
tqdm=4.61.0
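
If you install the packages manually, a small script like the one below can compare what is installed against the pinned versions. It is a sketch, not part of the repository: the pip distribution names (torch for conda's pytorch, and so on) are assumptions, and local build tags such as +cu111 are ignored when comparing.

```python
import sys

# Versions pinned in the list above. Keys are pip distribution names; mapping
# conda's "pytorch" to pip's "torch" (and similar renames) is an assumption.
PINNED = {
    "numpy": "1.20.2",
    "torch": "1.8.1",
    "torchaudio": "0.8.1",
    "torchvision": "0.9.1",
    "torch-cluster": "1.5.9",
    "torch-geometric": "1.7.0",
    "torch-scatter": "2.0.6",
    "torch-sparse": "0.6.9",
    "wilds": "1.1.0",
    "scikit-learn": "0.24.2",
    "scipy": "1.6.3",
    "seaborn": "0.11.1",
    "tqdm": "4.61.0",
}


def installed_version(name):
    """Return the installed version string of a distribution, or None."""
    try:
        from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
    except ImportError:  # Python 3.7 (the pinned interpreter): use setuptools
        import pkg_resources
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def check_versions(pinned, lookup=installed_version):
    """Return {name: (expected, found)} for every package that does not match."""
    mismatches = {}
    for name, expected in pinned.items():
        found = lookup(name)
        # Ignore local build tags such as "1.8.1+cu111" when comparing.
        base = found.split("+")[0] if found else None
        if base != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    print("python", sys.version.split()[0], "(pinned: 3.7.10)")
    for name, (expected, found) in check_versions(PINNED).items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
```

An empty report means every pinned package matches.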

Running Experiments

You can train either with our proposed method, Fish, or with the Empirical Risk Minimisation (ERM) baseline. Select one with the --algorithm flag (fish or erm).
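
To make the difference between the two --algorithm options concrete, here is a minimal pure-Python sketch on a hypothetical toy linear-regression problem with three domains. erm_step descends the average of the per-domain gradients. fish_step follows the first-order Fish update as described in the paper: clone the parameters, take one inner SGD step per domain in a random order, then move the original parameters a fraction meta_lr towards the clone. Full-domain gradients stand in for minibatches, and all names and hyperparameters are illustrative, not the ones used in main.py.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def domain_loss(w, data):
    """Mean of 0.5 * (w.x - y)^2 over one domain's (x, y) pairs."""
    return sum(0.5 * (dot(w, x) - y) ** 2 for x, y in data) / len(data)


def domain_grad(w, data):
    """Gradient of domain_loss with respect to w: mean of (w.x - y) * x."""
    g = [0.0] * len(w)
    for x, y in data:
        residual = dot(w, x) - y
        for k in range(len(w)):
            g[k] += residual * x[k] / len(data)
    return g


def erm_step(w, domains, lr):
    """ERM: one gradient step on the average of the per-domain losses."""
    grads = [domain_grad(w, d) for d in domains]
    avg = [sum(gs) / len(grads) for gs in zip(*grads)]
    return [wk - lr * gk for wk, gk in zip(w, avg)]


def fish_step(w, domains, inner_lr, meta_lr, rng):
    """Fish: sequential inner steps over the domains, then a meta update."""
    w_tilde = list(w)  # clone the current parameters
    order = list(range(len(domains)))
    rng.shuffle(order)  # visit the domains in a random order
    for i in order:
        g = domain_grad(w_tilde, domains[i])  # gradient at the updated clone
        w_tilde = [a - inner_lr * b for a, b in zip(w_tilde, g)]
    # Move the original parameters a fraction meta_lr towards the clone.
    return [a + meta_lr * (b - a) for a, b in zip(w, w_tilde)]


def make_domains(coefs, n=50, seed=0):
    """Toy data: y depends only on x[0]; x[1] tracks y with a per-domain coefficient."""
    rng = random.Random(seed)
    domains = []
    for c in coefs:
        data = []
        for _ in range(n):
            x0 = rng.gauss(0.0, 1.0)
            y = 2.0 * x0 + 0.1 * rng.gauss(0.0, 1.0)
            x1 = c * y + 0.1 * rng.gauss(0.0, 1.0)
            data.append(([x0, x1], y))
        domains.append(data)
    return domains


def avg_loss(w, domains):
    return sum(domain_loss(w, d) for d in domains) / len(domains)


if __name__ == "__main__":
    domains = make_domains([1.0, -0.5, 0.3])
    rng = random.Random(0)
    w_erm, w_fish = [0.0, 0.0], [0.0, 0.0]
    for _ in range(200):
        w_erm = erm_step(w_erm, domains, lr=0.1)
        w_fish = fish_step(w_fish, domains, inner_lr=0.05, meta_lr=0.5, rng=rng)
    print("ERM: ", w_erm, avg_loss(w_erm, domains))
    print("Fish:", w_fish, avg_loss(w_fish, domains))
```

To first order in inner_lr the Fish update reduces to an ERM step; according to the paper, the higher-order terms approximate the gradient-matching term of IDGM. The toy data only illustrates the mechanics of the two updates, not the generalisation benefit.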

CdSprites-N

We propose this simple shape-color dataset based on dSprites, a collection of white 2D sprites of different shapes, scales, rotations and positions. The dataset contains N domains, where N can be specified. The task is to classify the shape of each sprite. Within each training domain, every shape is deterministically paired with one color, and this pairing is specific to the domain, so shape is the invariant feature and color is the spurious one. On the test set, however, the correlation between color and shape is removed. See the image below for an illustration.

[Image: cdsprites]
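
The shape-color construction described above can be sketched as follows. This is a hypothetical illustration of the idea, not the code that generated the released dataset: the palette, the random per-domain assignment and the function names are assumptions, and image rendering is omitted. Only the three shapes (square, ellipse, heart) come from dSprites itself.

```python
import random

SHAPES = ["square", "ellipse", "heart"]  # the three shapes in dSprites
PALETTE = ["red", "green", "blue", "yellow", "cyan", "magenta"]  # hypothetical


def domain_color_maps(num_domains, palette=PALETTE, seed=0):
    """Draw a deterministic shape -> color map for each training domain."""
    rng = random.Random(seed)
    return [dict(zip(SHAPES, rng.sample(palette, len(SHAPES))))
            for _ in range(num_domains)]


def sample_train(shape_to_color, n, rng):
    """Training domain: the color is fully determined by the shape (spurious cue)."""
    samples = []
    for _ in range(n):
        shape = rng.choice(SHAPES)
        samples.append((shape, shape_to_color[shape]))
    return samples


def sample_test(n, rng, palette=PALETTE):
    """Test set: the color is drawn independently of the shape."""
    return [(rng.choice(SHAPES), rng.choice(palette)) for _ in range(n)]


if __name__ == "__main__":
    rng = random.Random(0)
    maps = domain_color_maps(5)
    for d, m in enumerate(maps):
        print(f"domain {d}: {m}")
    print(sample_train(maps[0], 4, rng))  # label = shape; color is a shortcut
    print(sample_test(4, rng))
```

A classifier that relies on color can be perfect within any single training domain yet fails once color is decoupled from shape, which is what the test split checks.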

The CdSprites-N dataset can be downloaded here. After downloading, extract the zip file to your preferred data directory (e.g. <your_data_dir>/cdsprites). The following command runs an experiment using Fish with N=15 domains:

python main.py --dataset cdsprites --algorithm fish --data-dir <your_data_dir> --num-domains 15

Supported values of N are 5, 10, 15, 20, 25, 30, 35, 40, 45 and 50.
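
A sweep over all supported values of N can be scripted by building the command line shown above. The helper below is a hypothetical convenience, not part of the repository; it uses only the flags that appear in this README and assumes that --num-domains applies to both algorithms.

```python
# Supported values of N, as listed above: 5, 10, ..., 50.
VALID_NUM_DOMAINS = list(range(5, 51, 5))
ALGORITHMS = ("fish", "erm")


def cdsprites_command(data_dir, num_domains, algorithm="fish"):
    """Return the argv for one CdSprites-N run, using only the flags shown above."""
    if num_domains not in VALID_NUM_DOMAINS:
        raise ValueError(f"N must be one of {VALID_NUM_DOMAINS}, got {num_domains}")
    if algorithm not in ALGORITHMS:
        raise ValueError(f"algorithm must be one of {ALGORITHMS}, got {algorithm!r}")
    return [
        "python", "main.py",
        "--dataset", "cdsprites",
        "--algorithm", algorithm,
        "--data-dir", data_dir,
        "--num-domains", str(num_domains),
    ]


if __name__ == "__main__":
    # Print every (N, algorithm) combination; "/path/to/data" is a placeholder.
    for n in VALID_NUM_DOMAINS:
        for algorithm in ALGORITHMS:
            print(" ".join(cdsprites_command("/path/to/data", n, algorithm)))
```

Each returned list can be passed directly to subprocess.run(cmd, check=True) to launch the run.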

WILDS

We include the following 6 datasets from the WILDS benchmark: amazon, camelyon, civil, fmow, iwildcam and poverty. The datasets can be downloaded automatically to a specified data folder. For instance, to train with Fish on the Amazon dataset, run:

python main.py --dataset amazon --algorithm fish --data-dir <your_data_dir>

This should automatically download the Amazon dataset to <your_data_dir>/wilds. Experiments on the other datasets can be run with the following commands:

python main.py --dataset camelyon --algorithm fish --data-dir <your_data_dir>
python main.py --dataset civil --algorithm fish --data-dir <your_data_dir>
python main.py --dataset fmow --algorithm fish --data-dir <your_data_dir>
python main.py --dataset iwildcam --algorithm fish --data-dir <your_data_dir>
python main.py --dataset poverty --algorithm fish --data-dir <your_data_dir>

Alternatively, you can download the datasets to <your_data_dir>/wilds manually by following the instructions here. See current results on WILDS here:

[Image: current results on WILDS]
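
The manual download can also be scripted with the wilds package's get_dataset function, fetching each dataset into <your_data_dir>/wilds before training. Two points are assumptions to check against the code before relying on this sketch: the mapping from this repository's --dataset names to WILDS identifiers (for example camelyon to camelyon17 and civil to civilcomments), and that main.py reads the data from that folder.

```python
import os

# Assumed mapping from this repository's --dataset names to the identifiers
# expected by wilds.get_dataset.
WILDS_NAMES = {
    "amazon": "amazon",
    "camelyon": "camelyon17",
    "civil": "civilcomments",
    "fmow": "fmow",
    "iwildcam": "iwildcam",
    "poverty": "poverty",
}


def wilds_root(data_dir):
    """Folder that the README says main.py downloads WILDS data into."""
    return os.path.join(data_dir, "wilds")


def predownload(data_dir, datasets=tuple(WILDS_NAMES)):
    """Download the selected WILDS datasets into <data_dir>/wilds."""
    from wilds import get_dataset  # imported lazily; requires wilds (pinned 1.1.0)

    root = wilds_root(data_dir)
    for short_name in datasets:
        get_dataset(dataset=WILDS_NAMES[short_name], root_dir=root, download=True)
```

For example, predownload("/path/to/data", datasets=("camelyon",)) fetches only Camelyon17; the path is a placeholder.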

DomainBed

For experiments on CMNIST, RMNIST, VLCS, PACS, OfficeHome, TerraInc and DomainNet, we implemented Fish in the DomainBed benchmark (see here), where you can compare it against up to 20 state-of-the-art baselines. See current results on DomainBed here:

[Image: current results on DomainBed]
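
Inside a DomainBed checkout, a single Fish run goes through DomainBed's own training script rather than main.py. The sketch below builds such a command. The dataset-name mapping and the flags (--data_dir, --dataset, --algorithm, --test_envs, --output_dir) are assumptions based on DomainBed's domainbed.scripts.train; check them against the DomainBed version you use.

```python
# Dataset abbreviations used above, mapped to DomainBed's dataset names.
DOMAINBED_DATASETS = {
    "CMNIST": "ColoredMNIST",
    "RMNIST": "RotatedMNIST",
    "VLCS": "VLCS",
    "PACS": "PACS",
    "OfficeHome": "OfficeHome",
    "TerraInc": "TerraIncognita",
    "DomainNet": "DomainNet",
}


def domainbed_fish_command(dataset, data_dir, test_envs, output_dir):
    """Return argv for one Fish run of domainbed.scripts.train."""
    return [
        "python", "-m", "domainbed.scripts.train",
        "--data_dir", data_dir,
        "--dataset", DOMAINBED_DATASETS[dataset],
        "--algorithm", "Fish",
        "--test_envs", *[str(e) for e in test_envs],  # held-out domain indices
        "--output_dir", output_dir,
    ]


if __name__ == "__main__":
    # Hold out environment 0 of PACS; the paths are placeholders.
    print(" ".join(domainbed_fish_command("PACS", "/path/to/domainbed/data", [0], "fish_pacs")))
```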

Citation

If you make use of this code in your research, we would appreciate it if you cited our paper:

@article{shi2021gradient,
	title="Gradient Matching for Domain Generalization.",
	author="Yuge {Shi} and Jeffrey {Seely} and Philip H. S. {Torr} and N. {Siddharth} and Awni {Hannun} and Nicolas {Usunier} and Gabriel {Synnaeve}",
	journal="arXiv preprint arXiv:2104.09937",
	year="2021"}

Contributions

We welcome contributions via pull requests. Please email [email protected] or [email protected] with any questions or requests.
