Cross-Task Consistency Learning Framework for Multi-Task Learning

Related tags

Deep Learningxtask_mt
Overview

Cross-Task Consistency Learning Framework for Multi-Task Learning

Tested on

  • numpy(v1.19.1)
  • opencv-python(v4.4.0.42)
  • torch(v1.7.0)
  • torchvision(v0.8.0)
  • tqdm(v4.48.2)
  • matplotlib(v3.3.1)
  • seaborn(v0.11.0)
  • pandas(v.1.1.2)

Data

Cityscapes (CS)

Download Cityscapes dataset and put it in a subdirectory named ./data/cityscapes. The folder should have the following subfolders:

  • RGB image in folder leftImg8bit
  • Segmentation in folder gtFine
  • Disparity maps in folder disparity

NYU

We use the preprocessed NYUv2 dataset provided by this repo. Download the dataset and put it in the dataset folder in ./data/nyu.

Model

The model consists of one encoder (ResNet) and two decoders, one for each task. The decoders outputs the predictions for each task ("direct predictions"), which are fed to the TaskTransferNet.
The objective of the TaskTranferNet is to predict the other task given a prediction image as an input (Segmentation prediction -> Depth prediction, vice versa), which I refer to as "transferred predictions"

Loss function

When computing the losses, the direct predictions are compared with the target while the transferred predictions are compared with the direct predictions so that they "align themselves".
The total loss consists of 4 different losses:

  • direct segmentation loss: CrossEntropyLoss()
  • direct depth loss: L1() or MSE() or logL1() or SmoothL1()
  • transferred segmentation loss:
    CrossEntropyLoss() or KLDivergence()
  • transferred depth loss: L1() or SSIM()

* Label smoothing: To "smooth" the one-hot probability by taking some of the probability from the correct class and distributing it among other classes.
* SSIM: Structural Similarity Loss

Flags

The flags are the same for both datasets. The flags and its usage are as written below,

Flag Name Usage Comments
input_path Path to dataset default is data/cityscapes (CS) or data/nyu (NYU)
height height of prediction default: 128 (CS) or 288 (NYU)
width width of prediction default: 256 (CS) or 384 (NYU)
epochs # of epochs default: 250 (CS) or 100 (NYU)
enc_layers which encoder to use default: 34, can choose from 18, 34, 50, 101, 152
use_pretrain toggle on to use pretrained encoder weights available for both datasets
batch_size batch size default: 8 (CS) or 6 (NYU)
scheduler_step_size step size for scheduler default: 80 (CS) or 60 (NYU), note that we use StepLR
scheduler_gamma decay rate of scheduler default: 0.5
alpha weight of adding transferred depth loss default: 0.01 (CS) or 0.0001 (NYU)
gamma weight of adding transferred segmentation loss default: 0.01 (CS) or 0.0001 (NYU)
label_smoothing amount of label smoothing default: 0.0
lp loss fn for direct depth loss default: L1, can choose from L1, MSE, logL1, smoothL1
tdep_loss loss fn for transferred depth loss default: L1, can choose from L1 or SSIM
tseg_loss loss fn for transferred segmentation loss default: cross, can choose from cross or kl
batch_norm toggle to enable batch normalization layer in TaskTransferNet slightly improves segmentation task
wider_ttnet toggle to double the # of channels in TaskTransferNet
uncertainty_weights toggle to use uncertainty weights (Kendall, et al. 2018) we used this for best results
gradnorm toggle to use GradNorm (Chen, et al. 2018)

Training

Cityscapes

For the Cityscapes dataset, there are two versions of segmentation task, which are 7-classes task and 19-classes task (Use flag 'num_classes' to switch tasks, default is 7).
So far, the results show near-SOTA for 7-class segmentation task + depth estimation.

ResNet34 was used as the encoder, L1() for direct depth loss and CrossEntropyLoss() for transferred segmentation loss.
The hyperparameter weights for both transferred predictions were 0.01.
I used Adam as my optimizer with an initial learning rate of 0.0001 and trained for 250 epochs with batch size 8. The learning rate was halved every 80 epochs.

To reproduce the code, use the following:

python main_cross_cs.py --uncertainty_weights

NYU

Our results show SOTA for NYU dataset.

ResNet34 was used as the encoder, L1() for direct depth loss and CrossEntropyLoss() for transferred segmentation loss.
The hyperparameter weights for both transferred predictions were 0.0001.
I used Adam as my optimizer with an initial learning rate of 0.0001 and trained for 100 epochs with batch size 6. The learning rate was halved every 60 epochs.

To reproduce the code, use the following:

python main_cross_nyu.py --uncertainty_weights

Comparisons

Evaluation metrics are the following:

Segmentation

  • Pixel accuracy (Pix Acc): percentage of pixels with the correct label
  • mIoU: mean Intersection over Union

Depth

  • Absolute Error (Abs)
  • Absolute Relative Error (Abs Rel): Absolute error divided by ground truth depth

The results are the following:

Cityscapes

Models mIoU Pix Acc Abs Abs Rel
MTAN 53.04 91.11 0.0144 33.63
KD4MTL 52.71 91.54 0.0139 27.33
PCGrad 53.59 91.45 0.0171 31.34
AdaMT-Net 62.53 94.16 0.0125 22.23
Ours 66.51 93.56 0.0122 19.40

NYU

Models mIoU Pix Acc Abs Abs Rel
MTAN* 21.07 55.70 0.6035 0.2472
MTAN† 20.10 53.73 0.6417 0.2758
KD4MTL* 20.75 57.90 0.5816 0.2445
KD4MTL† 22.44 57.32 0.6003 0.2601
PCGrad* 20.17 56.65 0.5904 0.2467
PCGrad† 21.29 54.07 0.6705 0.3000
AdaMT-Net* 21.86 60.35 0.5933 0.2456
AdaMT-Net† 20.61 58.91 0.6136 0.2547
Ours† 30.31 63.02 0.5954 0.2235

*: Trained on 3 tasks (segmentation, depth, and surface normal)
†: Trained on 2 tasks (segmentation and depth)
Italic: Reproduced by ourselves

Scores with models trained on 3 tasks for NYU dataset are shown only as reference.

Papers referred

MTAN: [paper][github]
KD4MTL: [paper][github]
PCGrad: [paper][github (tensorflow)][github (pytorch)]
AdaMT-Net: [paper]

Owner
Aki Nakano
Student at the University of Tokyo pursuing master's degree. Joined UC Berkeley Summer Session 2019. Researching deep learning. Python/R
Aki Nakano
Awesome-AI-books - Some awesome AI related books and pdfs for learning and downloading

Awesome AI books Some awesome AI related books and pdfs for downloading and learning. Preface This repo only used for learning, do not use in business

luckyzhou 1k Jan 01, 2023
Course materials for Fall 2021 "CIS6930 Topics in Computing for Data Science" at New College of Florida

Fall 2021 CIS6930 Topics in Computing for Data Science This repository hosts course materials used for a 13-week course "CIS6930 Topics in Computing f

Yoshi Suhara 101 Nov 30, 2022
Convert openmmlab (not only mmdetection) series model to tensorrt

MMDet to TensorRT This project aims to convert the mmdetection model to TensorRT model end2end. Focus on object detection for now. Mask support is exp

JinTian 4 Dec 17, 2021
A modular application for performing anomaly detection in networks

Deep-Learning-Models-for-Network-Annomaly-Detection The modular app consists for mainly three annomaly detection algorithms. The system supports model

Shivam Patel 1 Dec 09, 2021
bespoke tooling for offensive security's Windows Usermode Exploit Dev course (OSED)

osed-scripts bespoke tooling for offensive security's Windows Usermode Exploit Dev course (OSED) Table of Contents Standalone Scripts egghunter.py fin

epi 268 Jan 05, 2023
This is the repo for our work "Towards Persona-Based Empathetic Conversational Models" (EMNLP 2020)

Towards Persona-Based Empathetic Conversational Models (PEC) This is the repo for our work "Towards Persona-Based Empathetic Conversational Models" (E

Zhong Peixiang 35 Nov 17, 2022
A distributed deep learning framework that supports flexible parallelization strategies.

FlexFlow FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization stra

528 Dec 25, 2022
masscan + nmap + Finger

说明 个人根据使用习惯修改masnmap而来的一个小工具。调用masscan做全端口扫描,再调用nmap做服务识别,最后调用Finger做Web指纹识别。工具使用场景适合风险探测排查、众测等。 使用方法 安装依赖 pip3 install -r requirements.txt -i https:/

Ryan 3 Mar 25, 2022
Official repository for HOTR: End-to-End Human-Object Interaction Detection with Transformers (CVPR'21, Oral Presentation)

Official PyTorch Implementation for HOTR: End-to-End Human-Object Interaction Detection with Transformers (CVPR'2021, Oral Presentation) HOTR: End-to-

Kakao Brain 114 Nov 28, 2022
Official implementation of "SinIR: Efficient General Image Manipulation with Single Image Reconstruction" (ICML 2021)

SinIR (Official Implementation) Requirements To install requirements: pip install -r requirements.txt We used Python 3.7.4 and f-strings which are in

47 Oct 11, 2022
Lua-parser-lark - An out-of-box Lua parser written in Lark

An out-of-box Lua parser written in Lark Such parser handles a relaxed version o

Taine Zhao 2 Jul 19, 2022
Gender Classification Machine Learning Model using Sk-learn in Python with 97%+ accuracy and deployment

Gender-classification This is a ML model to classify Male and Females using some physical characterstics Data. Python Libraries like Pandas,Numpy and

Aryan raj 11 Oct 16, 2022
Analysing poker data from home games with friends

Poker Game Analysis Analysing poker data from home games with friends. Not a lot of data is collected, so this project is primarily focussed on descri

Stavros Karmaniolos 1 Oct 15, 2022
YOLOv7 - Framework Beyond Detection

🔥🔥🔥🔥 YOLO with Transformers and Instance Segmentation, with TensorRT acceleration! 🔥🔥🔥

JinTian 3k Jan 01, 2023
Code for ViTAS_Vision Transformer Architecture Search

Vision Transformer Architecture Search This repository open source the code for ViTAS: Vision Transformer Architecture Search. ViTAS aims to search fo

46 Dec 17, 2022
Repositório criado para abrigar os notebooks com a listas de exercícios propostos pelo professor Gustavo Guanabara do canal Curso em Vídeo do YouTube durante o Curso de Python 3

Curso em Vídeo - Exercícios de Python 3 Sobre o repositório Este repositório contém os notebooks com a listas de exercícios propostos pelo professor G

João Pedro Pereira 9 Oct 15, 2022
Neural network for recognizing the gender of people in photos

Neural Network For Gender Recognition How to test it? Install requirements.txt file using pip install -r requirements.txt command Run nn.py using pyth

Valery Chapman 1 Sep 18, 2022
Code accompanying "Evolving spiking neuron cellular automata and networks to emulate in vitro neuronal activity," accepted to IEEE SSCI ICES 2021

Evolving-spiking-neuron-cellular-automata-and-networks-to-emulate-in-vitro-neuronal-activity Code accompanying "Evolving spiking neuron cellular autom

SOCRATES: Self-Organizing Computational substRATES 2 Dec 02, 2022
MPLP: Metapath-Based Label Propagation for Heterogenous Graphs

MPLP: Metapath-Based Label Propagation for Heterogenous Graphs Results on MAG240M Here, we demonstrate the following performance on the MAG240M datase

Qiuying Peng 10 Jun 28, 2022
AutoPentest-DRL: Automated Penetration Testing Using Deep Reinforcement Learning

AutoPentest-DRL: Automated Penetration Testing Using Deep Reinforcement Learning AutoPentest-DRL is an automated penetration testing framework based o

Cyber Range Organization and Design Chair 217 Jan 01, 2023