This codebase is the official implementation of Test-Time Classifier Adjustment Module for Model-Agnostic Domain Generalization (NeurIPS2021, Spotlight)

Related tags

Deep LearningT3A
Overview

Test-Time Classifier Adjustment Module for Model-Agnostic Domain Generalization

This codebase is the official implementation of Test-Time Classifier Adjustment Module for Model-Agnostic Domain Generalization (NeurIPS2021, Spotlight). This codebase is mainly based on DomainBed, with following modifications:

  • enable to use various backbone networks including Big Transfer (BiT), Vision Transformers (ViT, DeiT, HViT), and MLP-Mixer.
  • enable to test test-time adaptation method (T3A and Tent).

Installation

CUDA/Python

git clone [email protected]:matsuolab/Domainbed_contrib.git
cd Domainbed_contrib/docker
docker build -t {image_name} .
docker run -it -h `hostname` --runtime=nvidia -v /path/to/Domainbed_contrib /path/to/anyware --shm-size=40gb --name {container_name} {image_name}

Python libralies

We use pipenv for package management.

cd /path/to/Domainbed_contrib
pip install pipenv
pipenv install
pipenv shell
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html

Quick start

(1) Downlload the datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path --dataset pacs

Note: change --dataset pacs for downloading other datasets (e.g., vlcs, office_home, terra_incognita).

(2) Train a model on source domains

python -m domainbed.scripts.train\
       --data_dir /my/datasets/path\
       --output_dir /my/pretrain/path\
       --algorithm ERM\
       --dataset PACS\
       --hparams "{\"backbone\": \"resnet50\"}" 

This scripts will produce new directory /my/pretrain/path, which include the full training log.

Note: change --dataset PACS for training on other datasets (e.g., VLCS, OfficeHome, TerraIncognita).

Note: change --hparams "{\"backbone\": \"resnet50\"}" for using other backbones (e.g., resnet18, ViT-B16, HViT).

(3) Evaluate model with test time adaptation (Table 1, Table 2, Figure 2)

python -m domainbed.scripts.unsupervised_adaptation\
       --input_dir=/my/pretrain/path\
       --adapt_algorithm=T3A

This scripts will produce a new file in /my/pretrain/path, whose name is results_{adapt_algorithm}.jsonl.

Note: change --adapt_algorithm=T3A for using other test time adaptation methods (T3A, Tent, or TentClf).

(4) Evaluate model with fine-tuning classifier(Figure 1)

python -m domainbed.scripts.supervised_adaptation\
       --input_dir=/my/pretrain/path\
       --ft_mode=clf

This scripts will produce a new file in /my/pretrain/path, whose name is results_{ft_mode}.jsonl.

Available backbones

  • resnet18
  • resnet50
  • BiT-M-R50x3
  • BiT-M-R101x3
  • BiT-M-R152x2
  • ViT-B16
  • ViT-L16
  • DeiT
  • Hybrid ViT (HViT)
  • MLP-Mixer (Mixer-L16)

Reproducing results

Table 1 and Figure 2 (Tuned ERM and CORAL)

You can use scripts/hparam_search.sh. Specifically, for each dataset and base algorithm, you can just type a following command.

sh scripts/hparam_search.sh resnet50 PACS ERM

Note that, it automatically starts 240 jobs, and take many times to finish.

Table 2 and Figure 1 (ERM with various backbone)

You can use scripts/launch.sh. Specifically, for each backbone, you can just type following three commands.

sh scripts/launch.sh pretrain resnet50 10 3 local
sh scripts/launch.sh sup resnet50 10 3 local
sh scripts/launch.sh unsup resnet50 10 3 local

Other results

For table 1, we used scores reported by In Search of Lost Domain Generalization. Full results for the reported scores in LaTeX format available here. Note: We only used scores for VLCS, PACS, OfficeHome, and TerraIncognita. We used the resutls with IIDAccuracySelectionMethod.

License

This source code is released under the MIT license, included here.

A simple and lightweight genetic algorithm for optimization of any machine learning model

geneticml This package contains a simple and lightweight genetic algorithm for optimization of any machine learning model. Installation Use pip to ins

Allan Barcelos 8 Aug 10, 2022
Udacity's CS101: Intro to Computer Science - Building a Search Engine

Udacity's CS101: Intro to Computer Science - Building a Search Engine All soluti

Phillip 0 Feb 26, 2022
Si Adek Keras is software VR dangerous object detection.

Si Adek Python Keras Sistem Informasi Deteksi Benda Berbahaya Keras Python. Version 1.0 Developed by Ananda Rauf Maududi. Developed date: 24 November

Ananda Rauf 1 Dec 21, 2021
Semi-supervised Semantic Segmentation with Directional Context-aware Consistency (CVPR 2021)

Semi-supervised Semantic Segmentation with Directional Context-aware Consistency (CAC) Xin Lai*, Zhuotao Tian*, Li Jiang, Shu Liu, Hengshuang Zhao, Li

Jia Research Lab 137 Dec 14, 2022
This repository contains the implementation of the following paper: Cross-Descriptor Visual Localization and Mapping

Cross-Descriptor Visual Localization and Mapping This repository contains the implementation of the following paper: "Cross-Descriptor Visual Localiza

Mihai Dusmanu 81 Oct 06, 2022
This code is for our paper "VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers"

ICCV Workshop 2021 VTGAN This code is for our paper "VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers"

Sharif Amit Kamran 25 Dec 08, 2022
The repository forked from NVlabs uses our data. (Differentiable rasterization applied to 3D model simplification tasks)

nvdiffmodeling [origin_code] Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Autom

Qiujie (Jay) Dong 2 Oct 31, 2022
Differentiable molecular simulation of proteins with a coarse-grained potential

Differentiable molecular simulation of proteins with a coarse-grained potential This repository contains the learned potential, simulation scripts and

UCL Bioinformatics Group 44 Dec 10, 2022
Project looking into use of autoencoder for semi-supervised learning and comparing data requirements compared to supervised learning.

Project looking into use of autoencoder for semi-supervised learning and comparing data requirements compared to supervised learning.

Tom-R.T.Kvalvaag 2 Dec 17, 2021
This repository is the official implementation of Open Rule Induction. This paper has been accepted to NeurIPS 2021.

Open Rule Induction This repository is the official implementation of Open Rule Induction. This paper has been accepted to NeurIPS 2021. Abstract Rule

Xingran Chen 16 Nov 14, 2022
DecoupledNet is semantic segmentation system which using heterogeneous annotations

DecoupledNet: Decoupled Deep Neural Network for Semi-supervised Semantic Segmentation Created by Seunghoon Hong, Hyeonwoo Noh and Bohyung Han at POSTE

Hyeonwoo Noh 74 Sep 22, 2021
CLDF dataset derived from Robbeets et al.'s "Triangulation Supports Agricultural Spread" from 2021

CLDF dataset derived from Robbeets et al.'s "Triangulation Supports Agricultural Spread" from 2021 How to cite If you use these data please cite the o

Digital Linguistics 2 Dec 20, 2021
This is an official implementation of "Polarized Self-Attention: Towards High-quality Pixel-wise Regression"

Polarized Self-Attention: Towards High-quality Pixel-wise Regression This is an official implementation of: Huajun Liu, Fuqiang Liu, Xinyi Fan and Don

DeLightCMU 212 Jan 08, 2023
Read and write layered TIFF ImageSourceData and ImageResources tags

Read and write layered TIFF ImageSourceData and ImageResources tags Psdtags is a Python library to read and write the Adobe Photoshop(r) specific Imag

Christoph Gohlke 4 Feb 05, 2022
QuakeLabeler is a Python package to create and manage your seismic training data, processes, and visualization in a single place — so you can focus on building the next big thing.

QuakeLabeler Quake Labeler was born from the need for seismologists and developers who are not AI specialists to easily, quickly, and independently bu

Hao Mai 15 Nov 04, 2022
ALL Snow Removed: Single Image Desnowing Algorithm Using Hierarchical Dual-tree Complex Wavelet Representation and Contradict Channel Loss (HDCWNet)

ALL Snow Removed: Single Image Desnowing Algorithm Using Hierarchical Dual-tree Complex Wavelet Representation and Contradict Channel Loss (HDCWNet) (

Wei-Ting Chen 49 Dec 27, 2022
Network Compression via Central Filter

Network Compression via Central Filter Environments The code has been tested in the following environments: Python 3.8 PyTorch 1.8.1 cuda 10.2 torchsu

2 May 12, 2022
Keras-tensorflow implementation of Fully Convolutional Networks for Semantic Segmentation(Unfinished)

Keras-FCN Fully convolutional networks and semantic segmentation with Keras. Models Models are found in models.py, and include ResNet and DenseNet bas

645 Dec 29, 2022
mPose3D, a mmWave-based 3D human pose estimation model.

mPose3D, a mmWave-based 3D human pose estimation model.

KylinChen 35 Nov 08, 2022
[ICCV 2021] Official Pytorch implementation for Discriminative Region-based Multi-Label Zero-Shot Learning SOTA results on NUS-WIDE and OpenImages

Discriminative Region-based Multi-Label Zero-Shot Learning (ICCV 2021) [arXiv][Project page coming soon] Sanath Narayan*, Akshita Gupta*, Salman Kh

Akshita Gupta 54 Nov 21, 2022