Training a Resilient Q-Network against Observational Interference, Causal Inference Q-Networks

Overview

Obs-Causal-Q-Network

AAAI 2022 - Training a Resilient Q-Network against Observational Interference

Preprint | Slides | Colab Demo | PyTorch

Environment Setup

  • Option 1: create the environment from the provided conda .yml file (built under conda 10.2 and Python 3.6)
conda env create -f obs-causal-q-conda.yml 
  • Option 2: start from a clean Python 3.6 environment, and follow the UnityAgent 3D environment setup for Banana Navigator
pip install torch torchvision torchaudio
pip install dowhy
pip install gym
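
After either option, a quick check that the packages are visible to the interpreter can save a failed training run. The snippet below is only a convenience sketch and is not part of the repository: the helper name `missing_modules` is hypothetical, and the module list simply mirrors the pip commands above.

```python
import importlib.util
import sys

# Top-level modules installed by the pip commands above.
REQUIRED = ("torch", "torchvision", "torchaudio", "dowhy", "gym")


def missing_modules(names=REQUIRED):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

The check only looks the modules up and does not import them, so it runs quickly even when PyTorch is installed.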

1. Example of Training a Causal Inference Q-Network (CIQ) on CartPole

  • Run Causal Inference Q-Network Training (--network 1 for Treatment Inference Q-network)
python 0-cartpole-main.py --network 1
  • Causal Inference Q-Network Architecture

  • Output Logs
observation space: Box(4,)
action space: Discrete(2)
Timing Atk Ratio: 10%
Using CEQNetwork_1. Number of Params: 41872
 Interference Type: 1  Use baseline:  0 use CGM:  1
With:  10.42 % timing attack
Episode 0   Score: 48.00, Average Score: 48.00, Loss: 1.71
With:  0.0 % timing attack
Episode 20   Score: 15.00, Average Score: 18.71, Loss: 30.56
With:  3.57 % timing attack
Episode 40   Score: 28.00, Average Score: 19.83, Loss: 36.36
With:  8.5 % timing attack
Episode 60   Score: 200.00, Average Score: 43.65, Loss: 263.29
With:  9.0 % timing attack
Episode 80   Score: 200.00, Average Score: 103.53, Loss: 116.35
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 193.4
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 164.2
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 147.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 193.4
With:  9.5 % timing attack
Episode 100   Score: 200.00, Average Score: 163.20, Loss: 77.38
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 198.4
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 197.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 197.6
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 198.6
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 199.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 186.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0

Environment solved in 114 episodes!     Average Score: 195.55
Environment solved in 114 episodes!     Average Score: 195.55 +- 25.07
############# Basic Evaluate #############
Using CEQNetwork_1. Number of Params: 41872
Evaluate Score : 200.0
############# Noise Evaluate #############
Using CEQNetwork_1. Number of Params: 41872
Robust Score : 200.0
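
The log reports a nominal "Timing Atk Ratio" of 10%, while the per-episode lines ("With: 10.42 % timing attack", "With: 0.0 % timing attack") show different realized values. One way such behavior can arise is per-step random interference, sketched below. This is an illustration under stated assumptions, not the repository's implementation: the function names, the Gaussian noise model and the `noise_std` parameter are all hypothetical.

```python
import random


def interfere(observation, ratio, rng, noise_std=1.0):
    """Return (possibly perturbed observation, interference label).

    With probability `ratio` the observation is perturbed with additive
    Gaussian noise and the label is 1; otherwise it is returned unchanged
    with label 0. (Assumed noise model, for illustration only.)
    """
    if rng.random() < ratio:
        noisy = [x + rng.gauss(0.0, noise_std) for x in observation]
        return noisy, 1
    return list(observation), 0


def realized_ratio(num_steps, ratio, seed=0):
    """Fraction of steps in one episode that were actually interfered."""
    rng = random.Random(seed)
    obs = [0.0, 0.0, 0.0, 0.0]  # CartPole observations have 4 entries
    labels = [interfere(obs, ratio, rng)[1] for _ in range(num_steps)]
    return sum(labels) / num_steps
```

Because each step is sampled independently in this sketch, short episodes can deviate noticeably from the nominal ratio, which would be consistent with the spread seen in the log. The 0/1 label returned alongside the observation is the kind of signal a treatment inference head could be trained on.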

2. Example of Training a "Variational" Causal Inference Q-Network on Unity 3D Banana Navigator

  • Run Variational Causal Inference Q-Networks (VCIQs) Training (--network 3 for Causal Variational Inference)
python 1-banana-navigator-main.py --network 3
  • Variational Causal Inference Q-Network Architecture

  • Output Logs
'Academy' started successfully!
Unity Academy name: Academy
        Number of Brains: 1
        Number of External Brains : 1
        Lesson number : 0
        Reset Parameters :

Unity brain name: BananaBrain
        Number of Visual Observations (per agent): 0
        Vector Observation space type: continuous
        Vector Observation space size (per agent): 37
        Number of stacked Vector Observation: 1
        Vector Action space type: discrete
        Vector Action space size (per agent): 4
        Vector Action descriptions: , , , 
Timing Atk Ratio: 10%
Using CEVAE_QNetwork.
Unity Worker id: 10  T: 1  Use baseline:  0  CEVAE:  1
With:  9.67 % timing attack
Episode 0   Score: 0.00, Average Score: 0.00
With:  11.0 % timing attack
Episode 5   Score: 1.00, Average Score: 0.17
With:  11.33 % timing attack
Episode 10   Score: 0.00, Average Score: 0.36
With:  10.33 % timing attack
Episode 15   Score: 0.00, Average Score: 0.56
...
Episode 205   Score: 10.00, Average Score: 9.25
With:  9.33 % timing attack
Episode 210   Score: 9.00, Average Score: 9.70
With:  9.0 % timing attack
Episode 215   Score: 10.00, Average Score: 11.10
With:  8.33 % timing attack
Episode 220   Score: 14.00, Average Score: 10.85
With:  12.33 % timing attack
Episode 225   Score: 19.00, Average Score: 11.70
With:  11.0 % timing attack
Episode 230   Score: 18.00, Average Score: 12.10
With:  7.67 % timing attack
Episode 235   Score: 21.00, Average Score: 11.60
With:  9.67 % timing attack
Episode 240   Score: 16.00, Average Score: 12.05

Environment solved in 242 episodes!     Average Score: 12.50
Environment solved in 242 episodes!     Average Score: 12.50 +- 4.87
############# Basic Evaluate #############
Using CEVAE_QNetwork.
Evaluate Score : 12.6
############# Noise Evaluate #############
Using CEVAE_QNetwork.
Robust Score : 12.5
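
To plot a learning curve from logs like the ones above, the "Episode ..." lines can be parsed with a small helper. This is a sketch that assumes the line format shown in this README; `parse_episode_lines` is a hypothetical name and is not provided by the repository.

```python
import re

# Matches lines such as:
#   Episode 20   Score: 15.00, Average Score: 18.71, Loss: 30.56
#   Episode 205   Score: 10.00, Average Score: 9.25
EPISODE_RE = re.compile(
    r"Episode\s+(\d+)\s+Score:\s*([-\d.]+),\s*Average Score:\s*([-\d.]+)"
    r"(?:,\s*Loss:\s*([-\d.]+))?"
)


def parse_episode_lines(log_text):
    """Extract (episode, score, average_score, loss_or_None) tuples."""
    records = []
    for line in log_text.splitlines():
        match = EPISODE_RE.search(line)
        if match:
            episode, score, average, loss = match.groups()
            records.append(
                (int(episode), float(score), float(average),
                 float(loss) if loss is not None else None)
            )
    return records
```

The loss field is optional in the pattern because the CartPole log prints it and the Banana Navigator log does not.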

Reference

This fun work was initiated when Danny and I first read about the Causal Variational Model between 2018 and 2019, with help from Dr. Yi Ouyang and Dr. Pin-Yu Chen.

Please consider citing the paper if you find this work helpful or relevant to your research.

@article{yang2021causal,
  title={Causal Inference Q-Network: Toward Resilient Reinforcement Learning},
  author={Yang, Chao-Han Huck and Hung, I-Te Danny and Ouyang, Yi and Chen, Pin-Yu},
  journal={arXiv preprint arXiv:2102.09677},
  year={2021}
}