Official implementation of paper Gradient Matching for Domain Generalization

Gradient Matching for Domain Generalisation

This is the official PyTorch implementation of Gradient Matching for Domain Generalisation. In our paper, we propose an inter-domain gradient matching (IDGM) objective that targets domain generalization by maximizing the inner product between gradients from different domains. To avoid computing the expensive second-order derivative of the IDGM objective, we derive a simpler first-order algorithm named Fish that approximates its optimization.
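
As a rough illustration of the first-order idea, the sketch below runs a Fish-style update on a toy one-parameter regression problem in plain Python. It is not the code in this repository: the model, the data, the helper names (`grad_mse`, `grad_inner_product`, `fish_step`) and the step sizes are all assumptions chosen for illustration. Each iteration clones the weight, takes one gradient step per domain in sequence on the clone, and then moves the original weight part of the way toward the clone.

```python
import random


def grad_mse(w, batch):
    """Gradient of the mean squared error for a 1-D linear model y ~ w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def grad_inner_product(w, batch_a, batch_b):
    """Inner product of the two domain gradients (a scalar product in 1-D)."""
    return grad_mse(w, batch_a) * grad_mse(w, batch_b)


def fish_step(w, domain_batches, inner_lr=0.05, meta_lr=0.5, rng=random):
    """One Fish-style iteration (illustrative, hypothetical hyperparameters)."""
    w_tilde = w  # clone of the current weight
    order = list(domain_batches)
    rng.shuffle(order)  # visit the domains in a random order
    for batch in order:  # one inner gradient step per domain
        w_tilde -= inner_lr * grad_mse(w_tilde, batch)
    # Outer update: interpolate from the original weight toward the clone.
    return w + meta_lr * (w_tilde - w)


# Two toy "domains" that share the same underlying slope (2.0).
domain_a = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
domain_b = [(0.5, 1.0), (1.5, 3.0)]

rng = random.Random(0)
w = 0.0
initial_product = grad_inner_product(w, domain_a, domain_b)
for _ in range(100):
    w = fish_step(w, [domain_a, domain_b], rng=rng)

print(f"gradient inner product at w=0: {initial_product:.2f}")
print(f"learned weight: {w:.4f}")
```

Because both toy domains agree on the slope, their gradients point the same way and the weight settles on the shared solution. No second-order derivative is computed at any point; only ordinary gradient steps are used.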

This repository contains code to reproduce the main results of our paper.

Dependencies

(Recommended) You can set up a conda environment with all required dependencies using environment.yml:

conda env create -f environment.yml
conda activate fish

Otherwise you can also install the following packages manually:

python=3.7.10
numpy=1.20.2
pytorch=1.8.1
torchaudio=0.8.1
torchvision=0.9.1
torch-cluster=1.5.9
torch-geometric=1.7.0
torch-scatter=2.0.6
torch-sparse=0.6.9
wilds=1.1.0
scikit-learn=0.24.2
scipy=1.6.3
seaborn=0.11.1
tqdm=4.61.0

Running Experiments

We offer options to train using our proposed method, Fish, or the Empirical Risk Minimisation (ERM) baseline. Specify the method with the --algorithm flag (either fish or erm).

CdSprites-N

We propose this simple shape-color dataset based on the dSprites dataset, which contains a collection of white 2D sprites of different shapes, scales, rotations and positions. The dataset contains N domains, where N can be specified. The goal is to classify the shape of the sprites. In the training data, each domain has its own deterministic shape-color matching, so shape is the invariant feature and color is the spurious feature. On the test set, this correlation between color and shape is removed. See the image below for an illustration.

[Image: cdsprites]
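
The per-domain matching described above can be sketched with a toy label generator. This is a minimal illustration and not the dataset's actual generation code: the shape and color names, the palette size and the helper functions are placeholders assumed for the example. Training samples take their color from a mapping that is fixed within each domain, while test samples draw shape and color independently.

```python
import random

# Placeholder names, assumed for illustration only.
SHAPES = ["square", "ellipse", "heart"]
PALETTE = ["red", "green", "blue", "yellow", "cyan", "magenta"]


def make_domain_mapping(rng):
    """Pick one distinct color per shape; the mapping is fixed for the domain."""
    colors = rng.sample(PALETTE, len(SHAPES))
    return dict(zip(SHAPES, colors))


def sample_train(mapping, rng):
    """Training sample: the color is fully determined by the shape."""
    shape = rng.choice(SHAPES)
    return shape, mapping[shape]


def sample_test(rng):
    """Test sample: shape and color are drawn independently."""
    return rng.choice(SHAPES), rng.choice(PALETTE)


rng = random.Random(0)
num_domains = 5
domain_mappings = [make_domain_mapping(rng) for _ in range(num_domains)]

# Each training entry is (domain index, shape, color).
train_set = [
    (d, *sample_train(domain_mappings[d], rng))
    for d in range(num_domains)
    for _ in range(20)
]
test_set = [sample_test(rng) for _ in range(200)]

for d, mapping in enumerate(domain_mappings):
    print(f"domain {d}: {mapping}")
```

A classifier that relies on color can fit each training domain perfectly in this toy setup, but color carries no information about shape in `test_set`.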

The CdSprites-N dataset can be downloaded here. After downloading, please extract the zip file to your preferred data dir (e.g. <your_data_dir>/cdsprites). The following command runs an experiment using Fish with number of domains N=15:

python main.py --dataset cdsprites --algorithm fish --data-dir <your_data_dir> --num-domains 15

You can choose the number of domains from: N = 5, 10, 15, 20, 25, 30, 35, 40, 45, 50.

WILDS

We include the following 6 datasets from the WILDS benchmark: amazon, camelyon, civil, fmow, iwildcam, poverty. The datasets can be downloaded automatically to a specified data folder. For instance, to train with Fish on the Amazon dataset, simply run:

python main.py --dataset amazon --algorithm fish --data-dir <your_data_dir>

This should automatically download the Amazon dataset to <your_data_dir>/wilds. Experiments on the other datasets can be run with the following commands:

python main.py --dataset camelyon --algorithm fish --data-dir <your_data_dir>
python main.py --dataset civil --algorithm fish --data-dir <your_data_dir>
python main.py --dataset fmow --algorithm fish --data-dir <your_data_dir>
python main.py --dataset iwildcam --algorithm fish --data-dir <your_data_dir>
python main.py --dataset poverty --algorithm fish --data-dir <your_data_dir>
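
To run every WILDS dataset with both Fish and the ERM baseline, the commands above can be generated from a short script. The sketch below only builds and prints the command lines; it uses the flags shown in this README, while the helper name `build_command` and the `/path/to/data` directory are assumptions for illustration.

```python
import shlex

WILDS_DATASETS = ["amazon", "camelyon", "civil", "fmow", "iwildcam", "poverty"]
ALGORITHMS = ["fish", "erm"]


def build_command(dataset, algorithm, data_dir, num_domains=None):
    """Build the argument list for one run of main.py (hypothetical helper)."""
    cmd = [
        "python", "main.py",
        "--dataset", dataset,
        "--algorithm", algorithm,
        "--data-dir", data_dir,
    ]
    # --num-domains is only shown for the cdsprites dataset in this README.
    if num_domains is not None:
        cmd += ["--num-domains", str(num_domains)]
    return cmd


commands = [
    build_command(dataset, algorithm, "/path/to/data")
    for dataset in WILDS_DATASETS
    for algorithm in ALGORITHMS
]

for cmd in commands:
    # Quote each argument so the printed line can be pasted into a shell.
    print(" ".join(shlex.quote(part) for part in cmd))
```

Each entry in `commands` is a plain argument list, so it could be passed to `subprocess.run(cmd, check=True)` to launch the runs one after another.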

Alternatively, you can also download the datasets to <your_data_dir>/wilds manually by following the instructions here. See current results on WILDS here:

[Image: WILDS results]

DomainBed

For experiments on datasets including CMNIST, RMNIST, VLCS, PACS, OfficeHome, TerraInc and DomainNet, we implemented Fish on the DomainBed benchmark (see here) and you can compare our algorithm against up to 20 SOTA baselines. See current results on DomainBed here:

[Image: DomainBed results]

Citation

If you make use of this code in your research, we would appreciate it if you cited our paper:

@article{shi2021gradient,
	title="Gradient Matching for Domain Generalization.",
	author="Yuge {Shi} and Jeffrey {Seely} and Philip H. S. {Torr} and N. {Siddharth} and Awni {Hannun} and Nicolas {Usunier} and Gabriel {Synnaeve}",
	journal="arXiv preprint arXiv:2104.09937",
	year="2021"}

Contributions

We welcome contributions via pull requests. Please email [email protected] or [email protected] with any questions or requests.
