PyTorch code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised DA

Related tags

Deep LearningSENTRY
Overview

PyTorch Code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation

Viraj Prabhu, Shivam Khare, Deeksha Kartik, Judy Hoffman

Many existing approaches for unsupervised domain adaptation (UDA) focus on adapting under only data distribution shift and offer limited success under additional cross-domain label distribution shift. Recent work based on self-training using target pseudolabels has shown promise, but on challenging shifts pseudolabels may be highly unreliable and using them for self-training may cause error accumulation and domain misalignment. We propose Selective Entropy Optimization via Committee Consistency (SENTRY), a UDA algorithm that judges the reliability of a target instance based on its predictive consistency under a committee of random image transformations. Our algorithm then selectively minimizes predictive entropy to increase confidence on highly consistent target instances, while maximizing predictive entropy to reduce confidence on highly inconsistent ones. In combination with pseudolabel-based approximate target class balancing, our approach leads to significant improvements over the state-of-the-art on 27/31 domain shifts from standard UDA benchmarks as well as benchmarks designed to stress-test adaptation under label distribution shift.

method

Table of Contents

Setup and Dependencies

  1. Create an anaconda environment with Python 3.6: conda create -n sentry python=3.6.8 and activate: conda activate sentry
  2. Navigate to the code directory: cd code/
  3. Install dependencies: pip install -r requirements.txt

And you're all set up!

Usage

Download data

Data for SVHN->MNIST is downloaded automatically via PyTorch. Data for other benchmarks can be downloaded from the following links. The splits used for our experiments are already included in the data/ folder):

  1. DomainNet
  2. OfficeHome
  3. VisDA2017 (only train and validation needed)

Pretrained checkpoints

To reproduce numbers reported in the paper, we include a a few pretrained checkpoints. We include checkpoints (source and adapted) for SVHN to MNIST (DIGITS) in the checkpoints directory. Source and adapted checkpoints for Clipart to Sketch adaptation (from DomainNet) and Real_World to Product adaptation (from OfficeHome RS-UT) can be downloaded from this link, and should be saved to the checkpoints/source and checkpoints/SENTRY directory as appropriate.

Train and adapt model

  • Natural label distribution shift: Adapt a model from to for a given (where benchmark may be DomainNet, OfficeHome, VisDA, or DIGITS), as follows:
python train.py --id <experiment_id> \
                --source <source> \
                --target <target> \
                --img_dir <image_directory> \
                --LDS_type <LDS_type> \
                --load_from_cfg True \
                --cfg_file 'config/<benchmark>/<cfg_file>.yml' \
                --use_cuda True

SENTRY hyperparameters are provided via a sentry.yml config file in the corresponding config/<benchmark> folder (On DIGITS, we also provide a config for baseline adaptation via DANN). The list of valid source/target domains per-benchmark are:

  • DomainNet: real, clipart, sketch, painting
  • OfficeHome_RS_UT: Real_World, Clipart, Product
  • OfficeHome: Real_World, Clipart, Product, Art
  • VisDA2017: visda_train, visda_test
  • DIGITS: Only svhn (source) to mnist (target) adaptation is currently supported.

Pass in the path to the parent folder containing dataset images via the --img_dir <name_of_directory> flag (eg. --img_dir '~/data/DomainNet'). Pass in the label distribution shift type via the --LDS_type flag: For DomainNet, OfficeHome (standard), and VisDA2017, pass in --LDS_type 'natural' (default). For OfficeHome RS-UT, pass in --LDS_type 'RS_UT'. For DIGITS, pass in --LDS_type as one of IF1, IF20, IF50, or IF100, to load a manually long-tailed target training split with a given imbalance factor (IF), as described in Table 4 of the paper.

To load a pretrained DA checkpoint instead of training your own, additionally pass --load_da True and --id <benchmark_name> to the script above. Finally, the training script will log performance metrics to the console (average and aggregate accuracy), and additionally plot and save some per-class performance statistics to the results/ folder.

Note: By default this code runs on GPU. To run on CPU pass: --use_cuda False

Reference

If you found this code useful, please consider citing:

@article{prabhu2020sentry
   author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
   title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
   year = {2020},
   journal = {arXiv preprint: 2012.11460},
}

Acknowledgements

We would like to thank the developers of PyTorch for building an excellent framework, in addition to the numerous contributors to all the open-source packages we use.

License

MIT

Image Completion with Deep Learning in TensorFlow

Image Completion with Deep Learning in TensorFlow See my blog post for more details and usage instructions. This repository implements Raymond Yeh and

Brandon Amos 1.3k Dec 23, 2022
Cookiecutter PyTorch Lightning

Cookiecutter PyTorch Lightning Instructions # install cookiecutter pip install cookiecutter

Mazen 8 Nov 06, 2022
CompilerGym is a library of easy to use and performant reinforcement learning environments for compiler tasks

CompilerGym is a library of easy to use and performant reinforcement learning environments for compiler tasks

Facebook Research 721 Jan 03, 2023
The Fundamental Clustering Problems Suite (FCPS) summaries 54 state-of-the-art clustering algorithms, common cluster challenges and estimations of the number of clusters as well as the testing for cluster tendency.

FCPS Fundamental Clustering Problems Suite The package provides over sixty state-of-the-art clustering algorithms for unsupervised machine learning pu

9 Nov 27, 2022
LieTransformer: Equivariant Self-Attention for Lie Groups

LieTransformer This repository contains the implementation of the LieTransformer used for experiments in the paper LieTransformer: Equivariant Self-At

OxCSML (Oxford Computational Statistics and Machine Learning) 50 Dec 28, 2022
coldcuts is an R package to automatically generate and plot segmentation drawings in R

coldcuts coldcuts is an R package that allows you to draw and plot automatically segmentations from 3D voxel arrays. The name is inspired by one of It

2 Sep 03, 2022
An implementation of the efficient attention module.

Efficient Attention An implementation of the efficient attention module. Description Efficient attention is an attention mechanism that substantially

Shen Zhuoran 194 Dec 15, 2022
🌈 PyTorch Implementation for EMNLP'21 Findings "Reasoning Visual Dialog with Sparse Graph Learning and Knowledge Transfer"

SGLKT-VisDial Pytorch Implementation for the paper: Reasoning Visual Dialog with Sparse Graph Learning and Knowledge Transfer Gi-Cheon Kang, Junseok P

Gi-Cheon Kang 9 Jul 05, 2022
The Wearables Development Toolkit - a development environment for activity recognition applications with sensor signals

Wearables Development Toolkit (WDK) The Wearables Development Toolkit (WDK) is a framework and set of tools to facilitate the iterative development of

Juan Haladjian 114 Nov 27, 2022
Self-Supervised Pre-Training for Transformer-Based Person Re-Identification

Self-Supervised Pre-Training for Transformer-Based Person Re-Identification [pdf] The official repository for Self-Supervised Pre-Training for Transfo

Hao Luo 116 Jan 04, 2023
Watch faces morph into each other with StyleGAN 2, StyleGAN, and DCGAN!

FaceMorpher FaceMorpher is an innovative project to get a unique face morph (or interpolation for geeks) on a website. Yes, this means you can see fac

Anish 9 Jun 24, 2022
Gradient representations in ReLU networks as similarity functions

Gradient representations in ReLU networks as similarity functions by Dániel Rácz and Bálint Daróczy. This repo contains the python code related to our

1 Oct 08, 2021
A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

TorchRL Disclaimer This library is not officially released yet and is subject to change. The features are available before an official release so that

Meta Research 860 Jan 07, 2023
This is the reference implementation for "Coresets via Bilevel Optimization for Continual Learning and Streaming"

Coresets via Bilevel Optimization This is the reference implementation for "Coresets via Bilevel Optimization for Continual Learning and Streaming" ht

Zalán Borsos 51 Dec 30, 2022
SimBERT升级版(SimBERTv2)!

RoFormer-Sim RoFormer-Sim,又称SimBERTv2,是我们之前发布的SimBERT模型的升级版。 介绍 https://kexue.fm/archives/8454 训练 tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6 下载

318 Dec 31, 2022
Official implementation of "Articulation Aware Canonical Surface Mapping"

Articulation-Aware Canonical Surface Mapping Nilesh Kulkarni, Abhinav Gupta, David F. Fouhey, Shubham Tulsiani Paper Project Page Requirements Python

Nilesh Kulkarni 56 Dec 16, 2022
Source code of our BMVC 2021 paper: AniFormer: Data-driven 3D Animation with Transformer

AniFormer This is the PyTorch implementation of our BMVC 2021 paper AniFormer: Data-driven 3D Animation with Transformer. Haoyu Chen, Hao Tang, Nicu S

24 Nov 02, 2022
Implementation of Artificial Neural Network Algorithm

Artificial Neural Network This repository contain implementation of Artificial Neural Network Algorithm in several programming languanges and framewor

Resha Dwika Hefni Al-Fahsi 1 Sep 14, 2022
PyTorch reimplementation of Diffusion Models

PyTorch pretrained Diffusion Models A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author'

Patrick Esser 265 Jan 01, 2023
The 7th edition of NTIRE: New Trends in Image Restoration and Enhancement workshop will be held on June 2022 in conjunction with CVPR 2022.

NTIRE 2022 - Image Inpainting Challenge Important dates 2022.02.01: Release of train data (input and output images) and validation data (only input) 2

Andrés Romero 37 Nov 27, 2022