Self-supervised learning of optimally robust representations for domain generalization.

Overview

OptDom: Learning Optimal Representations for Domain Generalization

This repository contains the official implementation for Optimal Representations for Covariate Shift. Our paper theoretically characterizes the minimal sufficient representations for optimal domain generalization (DG) under covariate shift and derives practical self-supervised learning (SSL) objectives for learning such representations.

We provide code for reproducing our main results. Contribution highlights:

  • Finetuning pretrained SSL models (CLIP) into more robust DG models [minimal example]
  • A novel contrastive adversarial domain bottleneck for learning domain-invariant representations [implementation]

Setup

  1. Install PyTorch 1.7.1 and CLIP following the instructions.
  2. Install other packages: pip install -r requirements.txt.
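After installation, a quick check that the main dependencies can be found may save a failed run later. The sketch below is not part of the repository; the helper name `missing_packages` is hypothetical, and the package list is taken from the imports used in the minimal example further down. It only checks that the packages can be located, not that the installed PyTorch version is 1.7.1.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Packages imported by the minimal finetuning example in this README.
required = ["torch", "clip", "tqdm"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages can be found.")
```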

Finetune & Evaluate CLIP on DomainBed

Our paper derives SSL objectives for learning optimally robust representations and gives insights into the superior robustness of CLIP (Sec 4). Here we include the code for finetuning CLIP with our proposed objectives and evaluating on the DomainBed benchmark, which reproduces our experiments in Sec 6.2.

The implementation is in the DomainBed directory, which is heavily based on the DomainBed repo. The CLIP-based models are implemented in domainbed/clip_algorithms.py, and the domain bottlenecks are in domainbed/bottlenecks.py. The training script for finetuning CLIP with bottlenecks is domainbed/scripts/train_clip.py.

Preparation

Move to the DomainBed directory and download the datasets:

python -m domainbed.scripts.download --data_dir ./datasets/

By default, we download the datasets: PACS, VLCS, OfficeHome, TerraIncognita, DomainNet.

Launch a single run

To launch a single run for debugging, run:

bash run_debug.sh

The key arguments include:

  • --dataset: dataset for finetuning and evaluation.
  • --algorithm: algorithms implemented with CLIP, see domainbed/clip_algorithms.py.
  • --test_envs: list of left-out environments for testing, others used for training/finetuning.
  • --hparams: JSON-serialized hyperparameter dict; see domainbed/hparams_registry.py for the list of all hyperparameters.

Note that the result of a single run can be very sensitive to the hyperparameters and the random seed. We recommend launching a sweep over hyperparameters and random seeds, as in DomainBed.
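As an illustration of how these arguments fit together, here is a minimal sketch that assembles one command per left-out environment. It is not the content of run_debug.sh. The helper `build_train_command` is hypothetical, the hyperparameter values are placeholders, the algorithm name follows the naming pattern used in the minimal example later in this README, and the `--data_dir` flag is assumed to be accepted by the training script as it is by the download script. PACS is used because it has four domains.

```python
import json
import shlex


def build_train_command(dataset, algorithm, test_env, hparams, data_dir="./datasets/"):
    """Assemble a shell command for one finetuning run with a single left-out environment."""
    args = [
        "python", "-m", "domainbed.scripts.train_clip",
        "--data_dir", data_dir,             # assumption: same flag as the download script
        "--dataset", dataset,
        "--algorithm", algorithm,
        "--test_envs", str(test_env),
        "--hparams", json.dumps(hparams),   # JSON-serialized hyperparameter dict
    ]
    # Quote every argument so the JSON string survives the shell intact.
    return " ".join(shlex.quote(arg) for arg in args)


# Placeholder overrides; see domainbed/hparams_registry.py for the real keys and defaults.
hparams = {"clip_model": "RN50", "batch_size": 32}

# One run per left-out PACS domain (indices 0..3); the other domains are used for finetuning.
commands = [
    build_train_command("PACS", "SupCLIPBottleneckCondCAD", env, hparams)
    for env in range(4)
]
for command in commands:
    print(command)
```

Serializing with `json.dumps` and quoting with `shlex.quote` avoids hand-escaping the braces and double quotes of the hyperparameter dict on the command line.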

Launch a sweep with tuning

To launch a sweep, run:

bash run_sweep_clip.sh

A sweep over 10 hyperparameters and 5 random seeds is launched for each dataset and algorithm. By default, the CLIP-RN50 model is used; you can run other models by changing the clip_model argument, e.g., ViT-B/32 for CLIP-ViT-B/32. To launch a sweep, you also need to select or implement a command launcher in domainbed/command_launchers.py and pass it via the launcher argument. If you are using Slurm, we already implement a Slurm launcher that you can adapt.

After the sweep finishes, you can collect the results with the notebook collect_clip_results.ipynb. Note that the results may differ slightly from those in the paper due to code cleaning.
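For readers who need their own launcher, the following is a minimal sketch. It assumes, as in upstream DomainBed, that a launcher is a plain function receiving a list of shell command strings and that launchers are looked up by name in a dictionary. The names `sequential_launcher` and `LAUNCHERS` and the `dry_run` parameter are hypothetical and should be checked against domainbed/command_launchers.py before use.

```python
import subprocess


def sequential_launcher(commands, dry_run=False):
    """Run shell commands one after another and return their exit codes.

    With dry_run=True the commands are only printed, which is handy for
    inspecting a sweep before spending compute on it.
    """
    exit_codes = []
    for command in commands:
        if dry_run:
            print(command)
            exit_codes.append(0)
        else:
            exit_codes.append(subprocess.call(command, shell=True))
    return exit_codes


# Hypothetical registry mapping a launcher name to its function.
LAUNCHERS = {"sequential": sequential_launcher}

print(LAUNCHERS["sequential"](["echo run-1", "echo run-2"], dry_run=True))
```

A cluster launcher would follow the same shape and replace the `subprocess.call` line with a job submission.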

(Optional) Run CAD in DomainBed setup

You can also evaluate our proposed (conditional) CAD bottleneck in the DomainBed setup, where a ResNet-50 is trained end-to-end on source domains and evaluated on a left-out target domain. We include the implementation in domainbed/algorithms.py, which you can run with:

bash run_sweep_e2e_dombed.sh

You can collect the results with the notebook collect_e2e_results.ipynb. Note that, as our paper argues, the algorithms in this setup lack access to information about the target domain, so we do not expect our bottlenecks or other algorithms to necessarily outperform ERM. Surprisingly, however, our CAD bottleneck does lead to consistent improvements.

Finetune CLIP on LAION-400M

Coming soon!

Minimal Code for Custom Finetuning

To finetune CLIP on your own dataset with our bottlenecks, start from this minimal code example:

import torch
from torch.utils.data import DataLoader, TensorDataset
import clip
from tqdm import tqdm

from domainbed import hparams_registry
from domainbed import algorithms


# 1. Determine whether you do supervised or contrastive finetuning:
#       - True: use a cross-entropy loss with a supervised dataset
#       - False: use a contrastive loss with a text-image-pair dataset
supervised_finetuning = True

if supervised_finetuning:
    loss_name = "Sup"
    dataset_name = "my supervised dataset"
else:
    loss_name = "Contrast"
    dataset_name = "my text-image pair dataset"


# 2. Determine the bottleneck you want to use with different properties
bottleneck_name = "CondCAD"  # Ent, CAD, CondCAD
algorithm_name = loss_name + "CLIPBottleneck" + bottleneck_name


# 3. Set hyperparameters, you can also change the hyperparameter dict and default values
hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)


# 4. Load pretrained CLIP models
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pretrained, preprocess = clip.load(hparams['clip_model'], device, jit=False)


# 5. Load your dataset, your dataset should have the form:
#       - (image, label) for supervised finetuning
#       - (image, text) for contrastive finetuning
#    Remember to use the CLIP preprocessing function for image transformation,
#       and your dataset should be a list of sub-datasets from different domains (singleton for a single domain)
dataset = load_your_dataset(dataset_name, preprocess)
num_envs = len(dataset)
num_classes = dataset.num_classes  # dummy for text-image-pair dataset


# 6. Featurize your dataset with CLIP models

def get_clip_feature(clip_model, x, y):
    """Compute CLIP image features and, for contrastive finetuning, text features"""
    with torch.no_grad():
        z = clip_model.encode_image(x).float()
        if not supervised_finetuning:  # `y` is a batch of texts that should be tokenized
            # Tokens are created on the CPU, so move them to the same device as the images
            y = clip_model.encode_text(clip.tokenize(y).to(x.device)).float()
    return z, y

def clip_featurize_data(clip_model, dataset, device):
    """Featurize a dataset"""
    Z, Y = [], []
    for x, y in tqdm(DataLoader(dataset, batch_size=512, num_workers=4)):
        # Labels are tensors and must be moved to the device; raw texts are strings and stay as they are
        if torch.is_tensor(y):
            y = y.to(device)
        z, y = get_clip_feature(clip_model, x.to(device), y)
        Z += [z.cpu()]
        Y += [y.cpu()]
    return TensorDataset(torch.cat(Z), torch.cat(Y))

def clip_precompute_splits(clip_model, splits, device):
    _splits = []
    for ds in splits:
        _splits.append(clip_featurize_data(clip_model, ds, device))
    return _splits


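dataset = clip_precompute_splits(pretrained, dataset, device)
train_loaders = [DataLoader(
    dataset=env,
    batch_size=hparams['batch_size'],
    num_workers=4)
    for env in dataset]

def infinite_minibatches(loaders):
    """Yield one minibatch per environment forever, restarting once the shortest loader is exhausted"""
    while True:
        yield from zip(*loaders)

# A plain zip(*train_loaders) would stop after one pass over the shortest environment
train_minibatches_iterator = infinite_minibatches(train_loaders)
steps_per_epoch = min(len(env) // hparams['batch_size'] for env in dataset)
n_steps = hparams['max_step']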


# 7. Initialize the model:
algorithm_class = algorithms.get_algorithm_class(algorithm_name)
algorithm = algorithm_class(pretrained.visual.output_dim, num_classes, num_envs, hparams, pretrained, None)
algorithm.to(device)
algorithm.train()


# 8. Finetune the model:
for step in range(n_steps):
    minibatches_device = [(x.to(device), y.to(device)) for x, y in next(train_minibatches_iterator)]
    algorithm.adjust_lr(step, n_steps, steps_per_epoch)
    step_vals = algorithm.update(minibatches_device, None)
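Step 5 of the example above relies on a user-provided load_your_dataset. The sketch below shows one possible shape for its return value: an indexable collection of per-domain datasets that also exposes num_classes. All class and function names here are hypothetical, and plain Python objects stand in for images so the sketch runs without PyTorch. In practice each sub-dataset would typically subclass torch.utils.data.Dataset, and transform would be the CLIP preprocess function.

```python
class DomainSubset:
    """Map-style dataset for one domain; items are (transformed image, label) pairs."""

    def __init__(self, samples, transform):
        self.samples = list(samples)  # list of (image, label) or (image, text) pairs
        self.transform = transform    # e.g. the CLIP `preprocess` function

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, target = self.samples[index]
        return self.transform(image), target


class MultiDomainDataset:
    """Collection of per-domain datasets; len() gives the number of domains."""

    def __init__(self, domains, num_classes):
        self.domains = list(domains)
        self.num_classes = num_classes

    def __len__(self):
        return len(self.domains)

    def __getitem__(self, index):
        return self.domains[index]


def build_multi_domain_dataset(samples_per_domain, transform, num_classes):
    """Wrap one list of samples per domain into a MultiDomainDataset."""
    domains = [DomainSubset(samples, transform) for samples in samples_per_domain]
    return MultiDomainDataset(domains, num_classes)


# Toy usage: strings stand in for images and str.upper stands in for preprocessing.
toy = build_multi_domain_dataset(
    [[("a", 0), ("b", 1)], [("c", 1)]], transform=str.upper, num_classes=2
)
print(len(toy), toy.num_classes, toy[0][1])
```

Keeping the domains as separate sub-datasets matches how the example builds one DataLoader per environment.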

Cite

If you find this work relevant to your work, please cite our paper:

@article{ruan2021optdom,
  title={Optimal Representations for Covariate Shift},
  author={Ruan, Yangjun and Dubois, Yann and Maddison, Chris J},
  journal={arXiv preprint arXiv:2201.00057},
  year={2022},
}

Acknowledgement

Our code is based on:

Owner
Yangjun Ruan
Ph.D. student @ UofT & Vector Previously undergrad @ ZJU