Implementation of our recent paper, WOOD: Wasserstein-based Out-of-Distribution Detection.

Related tags

Deep LearningWOOD
Overview

WOOD

Implementation of our recent paper, WOOD: Wasserstein-based Out-of-Distribution Detection.

Abstract

The training and test data for deep-neural-network-based classifiers are usually assumed to be sampled from the same distribution. When part of the test samples are drawn from a distribution that is sufficiently far away from that of the training samples (a.k.a. out-of-distribution (OOD) samples), the trained neural network has a tendency to make high confidence predictions for these OOD samples. Detection of the OOD samples is critical when training a neural network used for image classification, object detection, etc. It can enhance the classifier's robustness to irrelevant inputs, and improve the system resilience and security under different forms of attacks. Detection of OOD samples has three main challenges: (i) the proposed OOD detection method should be compatible with various architectures of classifiers (e.g., DenseNet, ResNet), without significantly increasing the model complexity and requirements on computational resources; (ii) the OOD samples may come from multiple distributions, whose class labels are commonly unavailable; (iii) a score function needs to be defined to effectively separate OOD samples from in-distribution (InD) samples. To overcome these challenges, we propose a Wasserstein-based out-of-distribution detection (WOOD) method. The basic idea is to define a Wasserstein-distance-based score that evaluates the dissimilarity between a test sample and the distribution of InD samples. An optimization problem is then formulated and solved based on the proposed score function. The statistical learning bound of the proposed method is investigated to guarantee that the loss value achieved by the empirical optimizer approximates the global optimum. The comparison study results demonstrate that the proposed WOOD consistently outperforms other existing OOD detection methods.

Citation

If you find our work useful in your research, please consider citing:

@misc{wang2021wood,
      title={WOOD: Wasserstein-based Out-of-Distribution Detection}, 
      author={Yinan Wang and Wenbo Sun and Jionghua "Judy" Jin and Zhenyu "James" Kong and Xiaowei Yue},
      year={2021},
      eprint={2112.06384},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Installation

The code has been tested on following environment

Ubuntu 18.04
python 3.6
CUDA 11.0
torch 1.4.0
scikit-learn 0.21.3
geomloss 0.2.3

Dataset

The experiments are conducted on MNIST, FashionMNIST, Cifar10, SVHN, and Tiny-ImageNet-200. The first four datasets can be automatically downloaded via PyTorch, the Tiny-ImageNet-200 needs to be manually downloaded and put the data files in the folder

Usage

WOOD

The performance of the proposed WOOD framework is tested using DenseNet as the backbone classifier.

CUDA_VISIBLE_DEVICES = ID  python main_OOD_binary.py [beta value] [number of epochs] [batch size] [InD batch size] [InD dataset] [OOD dataset] [Image channels]
CUDA_VISIBLE_DEVICES = ID  python main_OOD_dynamic.py [beta value] [number of epochs] [batch size] [InD batch size] [InD dataset] [OOD dataset] [Image channels]

e.g. CUDA_VISIBLE_DEVICES=0 python main_OOD_binary.py 0.1 60 60 50 Cifar10 Imagenet_c 3
     CUDA_VISIBLE_DEVICES=0 python main_OOD_dynamic.py 0.1 60 60 50 Cifar10 Imagenet_c 3

Note that the difference between main_OOD_binary.py and main_OOD_dynamic.py is the distance matrix used in the Wasserstein distance, which is discussed in our paper. The trained model is saved in directory. The model performance will be routinely tested during training.

Baseline Methods

The implementation of baseline methods is mainly based on the repo.

License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Acknowledgments

The implementation of DenseNet is base on the repo.

The implementation of Wasserstein distance is mainly base on geomloss.

Code for “ACE-HGNN: Adaptive Curvature ExplorationHyperbolic Graph Neural Network”

ACE-HGNN: Adaptive Curvature Exploration Hyperbolic Graph Neural Network This repository is the implementation of ACE-HGNN in PyTorch. Environment pyt

9 Nov 28, 2022
PyTorch implementation of "Image-to-Image Translation Using Conditional Adversarial Networks".

pix2pix-pytorch PyTorch implementation of Image-to-Image Translation Using Conditional Adversarial Networks. Based on pix2pix by Phillip Isola et al.

mrzhu 383 Dec 17, 2022
Code for Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty

Deep Deterministic Uncertainty This repository contains the code for Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic

Jishnu Mukhoti 69 Nov 28, 2022
Multi-Glimpse Network With Python

Multi-Glimpse Network Our code requires Python ≥ 3.8 Installation For example, venv + pip: $ python3 -m venv env $ source env/bin/activate (env) $ pyt

9 May 10, 2022
A Closer Look at Structured Pruning for Neural Network Compression

A Closer Look at Structured Pruning for Neural Network Compression Code used to reproduce experiments in https://arxiv.org/abs/1810.04622. To prune, w

Bayesian and Neural Systems Group 140 Dec 05, 2022
A PyTorch implementation of Implicit Q-Learning

IQL-PyTorch This repository houses a minimal PyTorch implementation of Implicit Q-Learning (IQL), an offline reinforcement learning algorithm, along w

Garrett Thomas 30 Dec 12, 2022
Unofficial Implementation of RobustSTL: A Robust Seasonal-Trend Decomposition Algorithm for Long Time Series (AAAI 2019)

RobustSTL: A Robust Seasonal-Trend Decomposition Algorithm for Long Time Series (AAAI 2019) This repository contains python (3.5.2) implementation of

Doyup Lee 222 Dec 21, 2022
Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

UTNet (Accepted at MICCAI 2021) Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation Introduction Transf

110 Jan 01, 2023
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Semantic Segmentation.

Swin Transformer for Semantic Segmentation of satellite images This repo contains the supported code and configuration files to reproduce semantic seg

23 Oct 10, 2022
Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering

Nvdiffrast – Modular Primitives for High-Performance Differentiable Rendering Modular Primitives for High-Performance Differentiable Rendering Samuli

NVIDIA Research Projects 675 Jan 06, 2023
Semiconductor Machine learning project

Wafer Fault Detection Problem Statement: Wafer (In electronics), also called a slice or substrate, is a thin slice of semiconductor, such as a crystal

kunal suryawanshi 1 Jan 15, 2022
TensorFlow implementation for Bayesian Modeling and Uncertainty Quantification for Learning to Optimize: What, Why, and How

Bayesian Modeling and Uncertainty Quantification for Learning to Optimize: What, Why, and How TensorFlow implementation for Bayesian Modeling and Unce

Shen Lab at Texas A&M University 8 Sep 02, 2022
This repo provides a demo for the CVPR 2021 paper "A Fourier-based Framework for Domain Generalization" on the PACS dataset.

FACT This repo provides a demo for the CVPR 2021 paper "A Fourier-based Framework for Domain Generalization" on the PACS dataset. To cite, please use:

105 Dec 17, 2022
MetaDrive: Composing Diverse Scenarios for Generalizable Reinforcement Learning

MetaDrive: Composing Diverse Driving Scenarios for Generalizable RL [ Documentation | Demo Video ] MetaDrive is a driving simulator with the following

DeciForce: Crossroads of Machine Perception and Autonomy 276 Jan 04, 2023
Used to record WKU's utility bills on a regular basis.

WKU水电费小助手 一个用于定期记录WKU水电费的脚本 Looking for English Readme? 背景 由于WKU校园内的水电账单系统时常存在扣费延迟的现象,而补扣的费用缺乏令人信服的证明。不少学生为费用摸不着头脑,但也没有申诉的依据。为了更好地掌握水电费使用情况,留下一手证据,我开源

2 Jul 21, 2022
Deep Sketch-guided Cartoon Video Inbetweening

Cartoon Video Inbetweening Paper | DOI | Video The source code of Deep Sketch-guided Cartoon Video Inbetweening by Xiaoyu Li, Bo Zhang, Jing Liao, Ped

Xiaoyu Li 37 Dec 22, 2022
Madanalysis5 - A package for event file analysis and recasting of LHC results

Welcome to MadAnalysis 5 Outline What is MadAnalysis 5? Requirements Downloading

MadAnalysis 15 Jan 01, 2023
Repository for code and dataset for our EMNLP 2021 paper - “So You Think You’re Funny?”: Rating the Humour Quotient in Standup Comedy.

AI-OpenMic Dataset The dataset is available for download via the follwing link. Repository for code and dataset for our EMNLP 2021 paper - “So You Thi

6 Oct 26, 2022
Bringing sanity to world of messed-up data

Sanitize sanitize is a Python module for making sure various things (e.g. HTML) are safe to use. It was originally written by Mark Pilgrim and is dist

Alireza Savand 63 Oct 26, 2021
A curated list of automated deep learning (including neural architecture search and hyper-parameter optimization) resources.

Awesome AutoDL A curated list of automated deep learning related resources. Inspired by awesome-deep-vision, awesome-adversarial-machine-learning, awe

D-X-Y 2k Dec 30, 2022