CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

Related tags

Deep LearningCDTrans
Overview

CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation [arxiv]

This is the official repository for CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

Introduction

Unsupervised domain adaptation (UDA) aims to transfer knowledge learned from a labeled source domain to a different unlabeled target domain. Most existing UDA methods focus on learning domain-invariant feature representation, either from the domain level or category level, using convolution neural networks (CNNs)-based frameworks. With the success of Transformer in various tasks, we find that the cross-attention in Transformer is robust to the noisy input pairs for better feature alignment, thus in this paper Transformer is adopted for the challenging UDA task. Specifically, to generate accurate input pairs, we design a two-way center-aware labeling algorithm to produce pseudo labels for target samples. Along with the pseudo labels, a weight-sharing triple-branch transformer framework is proposed to apply self-attention and cross-attention for source/target feature learning and source-target domain alignment, respectively. Such design explicitly enforces the framework to learn discriminative domain-specific and domain-invariant representations simultaneously. The proposed method is dubbed CDTrans (cross-domain transformer), and it provides one of the first attempts to solve UDA tasks with a pure transformer solution. Extensive experiments show that our proposed method achieves the best performance on all public UDA datasets including Office-Home, Office-31, VisDA-2017, and DomainNet.

framework

Results

Table 1 [UDA results on Office-31]

Methods Avg. A->D A->W D->A D->W W->A W->D
Baseline(DeiT-S) 86.7 87.6 86.9 74.9 97.7 73.5 99.6
model model model
CDTrans(DeiT-S) 90.4 94.6 93.5 78.4 98.2 78 99.6
model model model model model model
Baseline(DeiT-B) 88.8 90.8 90.4 76.8 98.2 76.4 100
model model model
CDTrans(DeiT-B) 92.6 97 96.7 81.1 99 81.9 100
model model model model model model

Table 2 [UDA results on Office-Home]

Methods Avg. Ar->Cl Ar->Pr Ar->Re Cl->Ar Cl->Pr Cl->Re Pr->Ar Pr->Cl Pr->Re Re->Ar Re->Cl Re->Pr
Baseline(DeiT-S) 69.8 55.6 73 79.4 70.6 72.9 76.3 67.5 51 81 74.5 53.2 82.7
model model model model
CDTrans(DeiT-S) 74.7 60.6 79.5 82.4 75.6 81.0 82.3 72.5 56.7 84.4 77.0 59.1 85.5
model model model model model model model model model model model model
Baseline(DeiT-B) 74.8 61.8 79.5 84.3 75.4 78.8 81.2 72.8 55.7 84.4 78.3 59.3 86
model model model model
CDTrans(DeiT-B) 80.5 68.8 85 86.9 81.5 87.1 87.3 79.6 63.3 88.2 82 66 90.6
model model model model model model model model model model model model

Table 3 [UDA results on VisDA-2017]

Methods Per-class plane bcycl bus car horse knife mcycl person plant sktbrd train truck
Baseline(DeiT-B) 67.3 (model) 98.1 48.1 84.6 65.2 76.3 59.4 94.5 11.8 89.5 52.2 94.5 34.1
CDTrans(DeiT-B) 88.4 (model) 97.7 86.39 86.87 83.33 97.76 97.16 95.93 84.08 97.93 83.47 94.59 55.3

Table 4 [UDA results on DomainNet]

Base-S clp info pnt qdr rel skt Avg. CDTrans-S clp info pnt qdr rel skt Avg.
clp - 21.2 44.2 15.3 59.9 46.0 37.3 clp - 25.3 52.5 23.2 68.3 53.2 44.5
model model model model model model model
info 36.8 - 39.4 5.4 52.1 32.6 33.3 info 47.6 - 48.3 9.9 62.8 41.1 41.9
model model model model model model model
pnt 47.1 21.7 - 5.7 60.2 39.9 34.9 pnt 55.4 24.5 - 11.7 67.4 48.0 41.4
model model model model model model model
qdr 25.0 3.3 10.4 - 18.8 14.0 14.3 qdr 36.6 5.3 19.3 - 33.8 22.7 23.5
model model model model model model model
rel 54.8 23.9 52.6 7.4 - 40.1 35.8 rel 61.5 28.1 56.8 12.8 - 47.2 41.3
model model model model model model model
skt 55.6 18.6 42.7 14.9 55.7 - 37.5 skt 64.3 26.1 53.2 23.9 66.2 - 46.7
model model model model model model model
Avg. 43.9 17.7 37.9 9.7 49.3 34.5 32.2 Avg. 53.08 21.86 46.02 16.3 59.7 42.44 39.9
Base-B clp info pnt qdr rel skt Avg. CDTrans-B clp info pnt qdr rel skt Avg.
clp - 24.2 48.9 15.5 63.9 50.7 40.6 clp - 29.4 57.2 26.0 72.6 58.1 48.7
model model model model model model model
info 43.5 - 44.9 6.5 58.8 37.6 38.3 info 57.0 - 54.4 12.8 69.5 48.4 48.4
model model model model model model model
pnt 52.8 23.3 - 6.6 64.6 44.5 38.4 pnt 62.9 27.4 - 15.8 72.1 53.9 46.4
model model model model model model model
qdr 31.8 6.1 15.6 - 23.4 18.9 19.2 qdr 44.6 8.9 29.0 - 42.6 28.5 30.7
model model model model model model model
rel 58.9 26.3 56.7 9.1 - 45.0 39.2 rel 66.2 31.0 61.5 16.2 - 52.9 45.6
model model model model model model model
skt 60.0 21.1 48.4 16.6 61.7 - 41.6 skt 69.0 29.6 59.0 27.2 72.5 - 51.5
model model model model model model model
Avg. 49.4 20.2 42.9 10.9 54.5 39.3 36.2 Avg. 59.9 25.3 52.2 19.6 65.9 48.4 45.2

Requirements

Installation

pip install -r requirements.txt
(Python version is the 3.7 and the GPU is the V100 with cuda 10.1, cudatoolkit 10.1)

Prepare Datasets

Download the UDA datasets Office-31, Office-Home, VisDA-2017, DomainNet

Then unzip them and rename them under the directory like follow: (Note that each dataset floader needs to make sure that it contains the txt file that contain the path and lable of the picture, which is already in data/the_dataset of this project.)

data
├── OfficeHomeDataset
│   │── class_name
│   │   └── images
│   └── *.txt
├── domainnet
│   │── class_name
│   │   └── images
│   └── *.txt
├── office31
│   │── class_name
│   │   └── images
│   └── *.txt
├── visda
│   │── train
│   │   │── class_name
│   │   │   └── images
│   │   └── *.txt 
│   └── validation
│       │── class_name
│       │   └── images
│       └── *.txt 

Prepare DeiT-trained Models

For fair comparison in the pre-training data set, we use the DeiT parameter init our model based on ViT. You need to download the ImageNet pretrained transformer model : DeiT-Small, DeiT-Base and move them to the ./data/pretrainModel directory.

Training

We utilize 1 GPU for pre-training and 2 GPUs for UDA, each with 16G of memory.

Scripts.

Command input paradigm

bash scripts/[pretrain/uda]/[office31/officehome/visda/domainnet]/run_*.sh [deit_base/deit_small]

For example

DeiT-Base scripts

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_base
bash scripts/uda/office31/run_office_amazon.sh deit_base

#Office-Home    Source: Art      ->  Target: Clipart, Product, Real_World
bash scripts/pretrain/officehome/run_officehome_Ar.sh deit_base
bash scripts/uda/officehome/run_officehome_Ar.sh deit_base

# VisDA-2017    Source: train    ->  Target: validation
bash scripts/pretrain/visda/run_visda.sh deit_base
bash scripts/uda/visda/run_visda.sh deit_base

# DomainNet     Source: Clipart  ->  Target: painting, quickdraw, real, sketch, infograph
bash scripts/pretrain/domainnet/run_domainnet_clp.sh deit_base
bash scripts/uda/domainnet/run_domainnet_clp.sh deit_base

DeiT-Small scripts Replace deit_base with deit_small to run DeiT-Small results. An example of training on office-31 is as follows:

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_small
bash scripts/uda/office31/run_office_amazon.sh deit_small

Evaluation

# For example VisDA-2017
python test.py --config_file 'configs/uda.yml' MODEL.DEVICE_ID "('0')" TEST.WEIGHT "('../logs/uda/vit_base/visda/transformer_best_model.pth')" DATASETS.NAMES 'VisDA' DATASETS.NAMES2 'VisDA' OUTPUT_DIR '../logs/uda/vit_base/visda/' DATASETS.ROOT_TRAIN_DIR './data/visda/train/train_image_list.txt' DATASETS.ROOT_TRAIN_DIR2 './data/visda/train/train_image_list.txt' DATASETS.ROOT_TEST_DIR './data/visda/validation/valid_image_list.txt'  

Acknowledgement

Codebase from TransReID

PyTorch implementation for the paper Pseudo Numerical Methods for Diffusion Models on Manifolds

Pseudo Numerical Methods for Diffusion Models on Manifolds (PNDM) This repo is the official PyTorch implementation for the paper Pseudo Numerical Meth

Luping Liu (刘路平) 196 Jan 05, 2023
Code repo for realtime multi-person pose estimation in CVPR'17 (Oral)

Realtime Multi-Person Pose Estimation By Zhe Cao, Tomas Simon, Shih-En Wei, Yaser Sheikh. Introduction Code repo for winning 2016 MSCOCO Keypoints Cha

Zhe Cao 4.9k Dec 31, 2022
InsTrim: Lightweight Instrumentation for Coverage-guided Fuzzing

InsTrim The paper: InsTrim: Lightweight Instrumentation for Coverage-guided Fuzzing Build Prerequisite llvm-8.0-dev clang-8.0 cmake = 3.2 Make git cl

75 Dec 23, 2022
PyTorch Implementation of [1611.06440] Pruning Convolutional Neural Networks for Resource Efficient Inference

PyTorch implementation of [1611.06440 Pruning Convolutional Neural Networks for Resource Efficient Inference] This demonstrates pruning a VGG16 based

Jacob Gildenblat 836 Dec 26, 2022
The authors' official PyTorch SigWGAN implementation

The authors' official PyTorch SigWGAN implementation This repository is the official implementation of [Sig-Wasserstein GANs for Time Series Generatio

9 Jun 16, 2022
Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs)

Why Spectral Normalization Stabilizes GANs: Analysis and Improvements [paper (NeurIPS 2021)] [paper (arXiv)] [code] Authors: Zinan Lin, Vyas Sekar, Gi

Zinan Lin 32 Dec 16, 2022
Benchmark library for high-dimensional HPO of black-box models based on Weighted Lasso regression

LassoBench LassoBench is a library for high-dimensional hyperparameter optimization benchmarks based on Weighted Lasso regression. Note: LassoBench is

Kenan Šehić 5 Mar 15, 2022
OoD Minimum Anomaly Score GAN - Code for the Paper 'OMASGAN: Out-of-Distribution Minimum Anomaly Score GAN for Sample Generation on the Boundary'

OMASGAN: Out-of-Distribution Minimum Anomaly Score GAN for Sample Generation on the Boundary Out-of-Distribution Minimum Anomaly Score GAN (OMASGAN) C

- 8 Sep 27, 2022
A Python package for time series augmentation

tsaug tsaug is a Python package for time series augmentation. It offers a set of augmentation methods for time series, as well as a simple API to conn

Arundo Analytics 278 Jan 01, 2023
Source code for PairNorm (ICLR 2020)

PairNorm Official pytorch source code for PairNorm paper (ICLR 2020) This code requires pytorch_geometric=1.3.2 usage For SGC, we use original PairNo

62 Dec 08, 2022
Official Pytorch implementation for "End2End Occluded Face Recognition by Masking Corrupted Features, TPAMI 2021"

End2End Occluded Face Recognition by Masking Corrupted Features This is the Pytorch implementation of our TPAMI 2021 paper End2End Occluded Face Recog

Haibo Qiu 25 Oct 31, 2022
Implementation of CVPR'21: RfD-Net: Point Scene Understanding by Semantic Instance Reconstruction

RfD-Net [Project Page] [Paper] [Video] RfD-Net: Point Scene Understanding by Semantic Instance Reconstruction Yinyu Nie, Ji Hou, Xiaoguang Han, Matthi

Yinyu Nie 162 Jan 06, 2023
Method for facial emotion recognition compitition of Xunfei and Datawhale .

人脸情绪识别挑战赛-第3名-W03KFgNOc-源代码、模型以及说明文档 队名:W03KFgNOc 排名:3 正确率: 0.75564 队员:yyMoming,xkwang,RichardoMu。 比赛链接:人脸情绪识别挑战赛 文章地址:link emotion 该项目分别训练八个模型并生成csv文

6 Oct 17, 2022
A tensorflow implementation of an HMM layer

tensorflow_hmm Tensorflow and numpy implementations of the HMM viterbi and forward/backward algorithms. See Keras example for an example of how to use

Zach Dwiel 283 Oct 19, 2022
Pytorch implementation for our ICCV 2021 paper "TRAR: Routing the Attention Spans in Transformers for Visual Question Answering".

TRAnsformer Routing Networks (TRAR) This is an official implementation for ICCV 2021 paper "TRAR: Routing the Attention Spans in Transformers for Visu

Ren Tianhe 49 Nov 10, 2022
FedML: A Research Library and Benchmark for Federated Machine Learning

FedML: A Research Library and Benchmark for Federated Machine Learning 📄 https://arxiv.org/abs/2007.13518 News 2021-02-01 (Award): #NeurIPS 2020# Fed

FedML-AI 2.3k Jan 08, 2023
CSE-519---Project - Job Title Analysis (Project for CSE 519 - Data Science Fundamentals)

A Multifaceted Approach to Job Title Analysis CSE 519 - Data Science Fundamentals Project Description Project consists of three parts: Salary Predicti

Jimit Dholakia 1 Jan 04, 2022
Rank1 Conversation Emotion Detection Task

Rank1-Conversation_Emotion_Detection_Task accuracy macro-f1 recall 0.826 0.7544 0.719 基于预训练模型和时序预测模型的对话情感探测任务 1 摘要 针对对话情感探测任务,本文将其分为文本分类和时间序列预测两个子任务,分

Yuchen Han 2 Nov 28, 2021
Source code and dataset for ACL2021 paper: "ERICA: Improving Entity and Relation Understanding for Pre-trained Language Models via Contrastive Learning".

ERICA Source code and dataset for ACL2021 paper: "ERICA: Improving Entity and Relation Understanding for Pre-trained Language Models via Contrastive L

THUNLP 75 Nov 02, 2022
Kaggle competition: Springleaf Marketing Response

PruebaEnel Prueba Kaggle-Springleaf-master Prueba Kaggle-Springleaf Kaggle competition: Springleaf Marketing Response Competencia de Kaggle: Marketing

1 Feb 09, 2022