CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation


This is the official repository for CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation (paper available on arXiv).

Introduction

Unsupervised domain adaptation (UDA) aims to transfer knowledge learned from a labeled source domain to a different, unlabeled target domain. Most existing UDA methods learn domain-invariant feature representations, at either the domain level or the category level, using convolutional neural network (CNN)-based frameworks. Given the success of Transformers in various tasks, we observe that cross-attention in a Transformer is robust to noisy input pairs, which benefits feature alignment, so we adopt a Transformer for the challenging UDA task.

Specifically, to generate accurate input pairs, we design a two-way center-aware labeling algorithm that produces pseudo labels for target samples. Using these pseudo labels, we propose a weight-sharing triple-branch transformer framework that applies self-attention for source/target feature learning and cross-attention for source-target domain alignment. This design explicitly forces the framework to learn discriminative domain-specific and domain-invariant representations simultaneously. The proposed method, dubbed CDTrans (cross-domain transformer), is one of the first attempts to solve UDA tasks with a pure transformer solution. Extensive experiments show that it achieves the best performance on four public UDA benchmarks: Office-Home, Office-31, VisDA-2017, and DomainNet.
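The two-way center-aware labeling algorithm is only named above, not specified. As a rough illustration of the "center-aware" idea alone, the sketch below refines target pseudo labels by building class centers from target features (weighted by classifier probabilities) and reassigning each sample to its most similar center by cosine similarity. This centroid-refinement scheme is an assumption borrowed from common UDA practice, not the repository's code; `center_aware_labels` is a hypothetical name, and the two-way source/target pairing step is not shown.

```python
import numpy as np


def _normalize(x):
    # Row-wise L2 normalisation; the epsilon guards against empty classes
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)


def center_aware_labels(feats, probs, n_iters=2):
    """Hypothetical sketch, not CDTrans's actual labeling code.

    feats: (N, D) target features; probs: (N, C) classifier softmax outputs.
    Returns (N,) integer pseudo labels refined by nearest class center.
    """
    feats = _normalize(feats)
    # Initial class centers: probability-weighted sums of target features
    centers = probs.T @ feats
    labels = probs.argmax(axis=1)
    for _ in range(n_iters):
        # Assign each sample to the center with the highest cosine similarity
        labels = (feats @ _normalize(centers).T).argmax(axis=1)
        # Re-estimate centers from the hard (one-hot) assignments
        centers = np.eye(probs.shape[1])[labels].T @ feats
    return labels
```

A sample the classifier gets wrong but whose feature lies close to the other samples of its true class can be corrected this way, which is the motivation for center-aware labels.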

[Figure: overview of the CDTrans framework]
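To make the self-attention/cross-attention split concrete, here is a minimal single-head NumPy sketch. The source and target branches attend within one domain, while the source-target branch takes queries from one domain and keys/values from the other; all three calls reuse one set of projection weights to mirror the weight sharing. The single head, the token shapes, the function names, and the choice of source tokens as queries in the cross branch are illustrative assumptions, not the repository's implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q_tokens, kv_tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product attention.

    Self-attention when q_tokens and kv_tokens come from the same image,
    cross-attention when they come from different domains.
    """
    q = q_tokens @ w_q
    k = kv_tokens @ w_k
    v = kv_tokens @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores, axis=-1) @ v


rng = np.random.default_rng(0)
d = 8
# One set of projection weights shared by all three branches
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))

src = rng.standard_normal((5, d))  # tokens of a source image
tgt = rng.standard_normal((5, d))  # tokens of its paired target image

src_out = attention(src, src, w_q, w_k, w_v)    # source branch (self)
tgt_out = attention(tgt, tgt, w_q, w_k, w_v)    # target branch (self)
cross_out = attention(src, tgt, w_q, w_k, w_v)  # source-target branch (cross)
```

Because the cross branch mixes target values under source queries, a wrongly paired input mainly changes the attention weights rather than the projections, which is one way to read the robustness-to-noisy-pairs claim.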

Results

Table 1 [UDA results on Office-31]

Methods                Avg.   A->D   A->W   D->A   D->W   W->A   W->D
Baseline (DeiT-S)      86.7   87.6   86.9   74.9   97.7   73.5   99.6
CDTrans (DeiT-S)       90.4   94.6   93.5   78.4   98.2   78.0   99.6
Baseline (DeiT-B)      88.8   90.8   90.4   76.8   98.2   76.4  100.0
CDTrans (DeiT-B)       92.6   97.0   96.7   81.1   99.0   81.9  100.0

Table 2 [UDA results on Office-Home]

Methods                 Avg.  Ar->Cl  Ar->Pr  Ar->Re  Cl->Ar  Cl->Pr  Cl->Re  Pr->Ar  Pr->Cl  Pr->Re  Re->Ar  Re->Cl  Re->Pr
Baseline (DeiT-S)       69.8    55.6    73.0    79.4    70.6    72.9    76.3    67.5    51.0    81.0    74.5    53.2    82.7
CDTrans (DeiT-S)        74.7    60.6    79.5    82.4    75.6    81.0    82.3    72.5    56.7    84.4    77.0    59.1    85.5
Baseline (DeiT-B)       74.8    61.8    79.5    84.3    75.4    78.8    81.2    72.8    55.7    84.4    78.3    59.3    86.0
CDTrans (DeiT-B)        80.5    68.8    85.0    86.9    81.5    87.1    87.3    79.6    63.3    88.2    82.0    66.0    90.6

Table 3 [UDA results on VisDA-2017]

Methods              Per-class   plane   bcycl     bus     car   horse   knife   mcycl  person   plant  sktbrd   train   truck
Baseline (DeiT-B)         67.3    98.1    48.1    84.6    65.2    76.3    59.4    94.5    11.8    89.5    52.2    94.5    34.1
CDTrans (DeiT-B)          88.4    97.7   86.39   86.87   83.33   97.76   97.16   95.93   84.08   97.93   83.47   94.59    55.3
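The "Per-class" column reads as the unweighted mean of the twelve class accuracies, the usual VisDA-2017 metric. As a quick check, averaging the CDTrans (DeiT-B) row reproduces the reported 88.4; `mean_per_class` is just an illustrative helper.

```python
# CDTrans (DeiT-B) per-class accuracies on VisDA-2017, copied from Table 3
cdtrans_visda = {
    "plane": 97.7, "bcycl": 86.39, "bus": 86.87, "car": 83.33,
    "horse": 97.76, "knife": 97.16, "mcycl": 95.93, "person": 84.08,
    "plant": 97.93, "sktbrd": 83.47, "train": 94.59, "truck": 55.3,
}


def mean_per_class(acc):
    # Unweighted mean: every class counts equally,
    # regardless of how many validation images it has
    return sum(acc.values()) / len(acc)


print(f"{mean_per_class(cdtrans_visda):.1f}")  # 88.4
```

Averaging per class rather than per image keeps large classes from dominating the score, which matters on VisDA-2017 where hard classes such as truck would otherwise be hidden.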

Table 4 [UDA results on DomainNet]

Base-S/Base-B: Baseline (DeiT-S/DeiT-B); CDTrans-S/CDTrans-B: CDTrans (DeiT-S/DeiT-B). "-" marks the same-domain diagonal; the "Avg." column and row give row and column means.

Base-S        clp   info    pnt    qdr    rel    skt   Avg.
clp             -   21.2   44.2   15.3   59.9   46.0   37.3
info         36.8      -   39.4    5.4   52.1   32.6   33.3
pnt          47.1   21.7      -    5.7   60.2   39.9   34.9
qdr          25.0    3.3   10.4      -   18.8   14.0   14.3
rel          54.8   23.9   52.6    7.4      -   40.1   35.8
skt          55.6   18.6   42.7   14.9   55.7      -   37.5
Avg.         43.9   17.7   37.9    9.7   49.3   34.5   32.2

CDTrans-S     clp   info    pnt    qdr    rel    skt   Avg.
clp             -   25.3   52.5   23.2   68.3   53.2   44.5
info         47.6      -   48.3    9.9   62.8   41.1   41.9
pnt          55.4   24.5      -   11.7   67.4   48.0   41.4
qdr          36.6    5.3   19.3      -   33.8   22.7   23.5
rel          61.5   28.1   56.8   12.8      -   47.2   41.3
skt          64.3   26.1   53.2   23.9   66.2      -   46.7
Avg.        53.08  21.86  46.02   16.3   59.7  42.44   39.9

Base-B        clp   info    pnt    qdr    rel    skt   Avg.
clp             -   24.2   48.9   15.5   63.9   50.7   40.6
info         43.5      -   44.9    6.5   58.8   37.6   38.3
pnt          52.8   23.3      -    6.6   64.6   44.5   38.4
qdr          31.8    6.1   15.6      -   23.4   18.9   19.2
rel          58.9   26.3   56.7    9.1      -   45.0   39.2
skt          60.0   21.1   48.4   16.6   61.7      -   41.6
Avg.         49.4   20.2   42.9   10.9   54.5   39.3   36.2

CDTrans-B     clp   info    pnt    qdr    rel    skt   Avg.
clp             -   29.4   57.2   26.0   72.6   58.1   48.7
info         57.0      -   54.4   12.8   69.5   48.4   48.4
pnt          62.9   27.4      -   15.8   72.1   53.9   46.4
qdr          44.6    8.9   29.0      -   42.6   28.5   30.7
rel          66.2   31.0   61.5   16.2      -   52.9   45.6
skt          69.0   29.6   59.0   27.2   72.5      -   51.5
Avg.         59.9   25.3   52.2   19.6   65.9   48.4   45.2
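Each DomainNet sub-table is a 6×6 grid of transfer tasks with an empty diagonal. The reported numbers are consistent with the "Avg." column being each row's mean over its five tasks, the "Avg." row being each column's mean, and the corner cell being the mean over all 30 tasks. The sketch below recomputes them for the CDTrans-S sub-table; `summarize` is an illustrative helper.

```python
DOMAINS = ["clp", "info", "pnt", "qdr", "rel", "skt"]

# CDTrans-S sub-table from Table 4; None marks the empty diagonal
CDTRANS_S = [
    [None, 25.3, 52.5, 23.2, 68.3, 53.2],
    [47.6, None, 48.3, 9.9, 62.8, 41.1],
    [55.4, 24.5, None, 11.7, 67.4, 48.0],
    [36.6, 5.3, 19.3, None, 33.8, 22.7],
    [61.5, 28.1, 56.8, 12.8, None, 47.2],
    [64.3, 26.1, 53.2, 23.9, 66.2, None],
]


def _mean(values):
    vals = [v for v in values if v is not None]  # skip the diagonal
    return sum(vals) / len(vals)


def summarize(matrix):
    row_avg = [_mean(row) for row in matrix]        # "Avg." column
    col_avg = [_mean(col) for col in zip(*matrix)]  # "Avg." row
    overall = _mean([v for row in matrix for v in row])  # corner cell
    return row_avg, col_avg, overall


rows, cols, overall = summarize(CDTRANS_S)
print(round(rows[0], 1), round(cols[0], 2), round(overall, 1))  # 44.5 53.08 39.9
```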

Requirements

Installation

pip install -r requirements.txt
(Tested with Python 3.7 on a V100 GPU with CUDA 10.1 and cudatoolkit 10.1.)

Prepare Datasets

Download the UDA datasets: Office-31, Office-Home, VisDA-2017, and DomainNet.

Then unzip them and rename the folders to match the layout below. Each dataset folder must contain the .txt list files that give the path and label of every image; these files are already provided in the corresponding data/<dataset> folder of this project.

data
├── OfficeHomeDataset
│   │── class_name
│   │   └── images
│   └── *.txt
├── domainnet
│   │── class_name
│   │   └── images
│   └── *.txt
├── office31
│   │── class_name
│   │   └── images
│   └── *.txt
├── visda
│   │── train
│   │   │── class_name
│   │   │   └── images
│   │   └── *.txt 
│   └── validation
│       │── class_name
│       │   └── images
│       └── *.txt 
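Before training, it can help to confirm that every entry in a list file points to an existing image. The sketch assumes one whitespace-separated `path label` pair per line with an integer label last, and resolves paths relative to the list file's folder by default; both are assumptions, so check the files shipped under `data/` first. `read_image_list` and `missing_files` are hypothetical helpers, not part of the codebase.

```python
from pathlib import Path


def read_image_list(list_file, root=None):
    """Parse 'relative/path.jpg label' lines into (Path, int) pairs.

    Assumed format: whitespace-separated, integer label last; blank lines
    are skipped. Paths are resolved against `root` (default: the list
    file's own folder).
    """
    list_file = Path(list_file)
    root = Path(root) if root is not None else list_file.parent
    samples = []
    text = list_file.read_text()
    for line_no, line in enumerate(text.splitlines(), start=1):
        parts = line.strip().rsplit(maxsplit=1)
        if not parts:
            continue  # blank line
        if len(parts) != 2:
            raise ValueError(f"{list_file}:{line_no}: expected 'path label'")
        path, label = parts
        samples.append((root / path, int(label)))
    return samples


def missing_files(samples):
    # Entries whose image is not on disk, e.g. after a folder was misnamed
    return [p for p, _ in samples if not p.exists()]
```

Splitting from the right keeps file names that contain spaces intact, as long as the label is the last field.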

Prepare DeiT-pretrained Models

For a fair comparison in terms of pre-training data, we initialize our ViT-based model with DeiT parameters. Download the ImageNet-pretrained transformer models, DeiT-Small and DeiT-Base, and move them to the ./data/pretrainModel directory.

Training

We use 1 GPU for pre-training and 2 GPUs for UDA, each with 16 GB of memory.

Scripts

Command pattern:

bash scripts/[pretrain/uda]/[office31/officehome/visda/domainnet]/run_*.sh [deit_base/deit_small]

For example

DeiT-Base scripts

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_base
bash scripts/uda/office31/run_office_amazon.sh deit_base

# Office-Home   Source: Art      ->  Target: Clipart, Product, Real_World
bash scripts/pretrain/officehome/run_officehome_Ar.sh deit_base
bash scripts/uda/officehome/run_officehome_Ar.sh deit_base

# VisDA-2017    Source: train    ->  Target: validation
bash scripts/pretrain/visda/run_visda.sh deit_base
bash scripts/uda/visda/run_visda.sh deit_base

# DomainNet     Source: Clipart  ->  Target: painting, quickdraw, real, sketch, infograph
bash scripts/pretrain/domainnet/run_domainnet_clp.sh deit_base
bash scripts/uda/domainnet/run_domainnet_clp.sh deit_base

DeiT-Small scripts

Replace deit_base with deit_small to reproduce the DeiT-Small results. For example, on Office-31:

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_small
bash scripts/uda/office31/run_office_amazon.sh deit_small
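The command pattern `bash scripts/[pretrain/uda]/[dataset]/run_*.sh [backbone]` can also be assembled in Python, for example to loop over several source domains. `script_command` is a hypothetical helper: it only formats the path and validates the choices listed in the pattern, and it cannot know which `run_*.sh` names exist, so check the `scripts/` folder for the exact names.

```python
import shlex

STAGES = ("pretrain", "uda")
DATASETS = ("office31", "officehome", "visda", "domainnet")
BACKBONES = ("deit_base", "deit_small")


def script_command(stage, dataset, run_name, backbone):
    """Build ['bash', 'scripts/<stage>/<dataset>/run_<run_name>.sh', backbone].

    run_name is the part matched by run_*.sh, e.g. 'office_amazon'.
    """
    if stage not in STAGES:
        raise ValueError(f"stage must be one of {STAGES}")
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}")
    if backbone not in BACKBONES:
        raise ValueError(f"backbone must be one of {BACKBONES}")
    return ["bash", f"scripts/{stage}/{dataset}/run_{run_name}.sh", backbone]


# Pre-training first, then UDA, in the same order as the examples above
for stage in STAGES:
    cmd = script_command(stage, "office31", "office_amazon", "deit_small")
    print(" ".join(shlex.quote(part) for part in cmd))
# bash scripts/pretrain/office31/run_office_amazon.sh deit_small
# bash scripts/uda/office31/run_office_amazon.sh deit_small
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root. `shlex.quote` is used instead of `shlex.join` because the latter needs Python 3.8, newer than the tested 3.7.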

Evaluation

# Example: VisDA-2017
python test.py --config_file 'configs/uda.yml' MODEL.DEVICE_ID "('0')" \
    TEST.WEIGHT "('../logs/uda/vit_base/visda/transformer_best_model.pth')" \
    DATASETS.NAMES 'VisDA' DATASETS.NAMES2 'VisDA' OUTPUT_DIR '../logs/uda/vit_base/visda/' \
    DATASETS.ROOT_TRAIN_DIR './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TRAIN_DIR2 './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TEST_DIR './data/visda/validation/valid_image_list.txt'
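`test.py` takes a YAML config followed by flat `KEY VALUE` overrides, the style of yacs-based configs in TransReID-derived code (treat the yacs detail as an assumption). When evaluating many checkpoints, it can be easier to build the argument list in Python; `evaluation_command` is a hypothetical helper, shown here with the overrides from the VisDA-2017 example.

```python
def evaluation_command(overrides, config_file="configs/uda.yml"):
    """Build the argv list for test.py: config file, then KEY VALUE pairs."""
    cmd = ["python", "test.py", "--config_file", config_file]
    for key, value in overrides.items():
        cmd.extend([key, str(value)])
    return cmd


# Values are passed verbatim (no shell), so "('0')" reaches test.py exactly
# as the double-quoted "('0')" does in the shell command.
VISDA_OVERRIDES = {
    "MODEL.DEVICE_ID": "('0')",
    "TEST.WEIGHT": "('../logs/uda/vit_base/visda/transformer_best_model.pth')",
    "DATASETS.NAMES": "VisDA",
    "DATASETS.NAMES2": "VisDA",
    "OUTPUT_DIR": "../logs/uda/vit_base/visda/",
    "DATASETS.ROOT_TRAIN_DIR": "./data/visda/train/train_image_list.txt",
    "DATASETS.ROOT_TRAIN_DIR2": "./data/visda/train/train_image_list.txt",
    "DATASETS.ROOT_TEST_DIR": "./data/visda/validation/valid_image_list.txt",
}

cmd = evaluation_command(VISDA_OVERRIDES)
```

From the repository root, `subprocess.run(cmd, check=True)` would then run the same evaluation as the shell command; swapping `TEST.WEIGHT` and `OUTPUT_DIR` in a loop evaluates several checkpoints.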

Acknowledgement

Our codebase is built on TransReID.
