VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries

VACA

Code repository for the paper "VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries (arXiv)". The implementation is based on PyTorch, PyTorch Geometric and PyTorch Lightning. The repository contains the resources needed to run the experiments of the paper. Follow the instructions below to download the German dataset.

Installation

Create conda environment and activate it:

conda create --name vaca python=3.9 --no-default-packages
conda activate vaca 

Option 1: Import the conda environment

conda env create -f environment.yml

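Option 2: Install the packages manually (the torch-geometric wheels below target PyTorch 1.9.0 on CPU; adjust the URL if your PyTorch version differs)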

conda install pip
pip install torch torchvision torchaudio
pip install pytorch-lightning
pip install -U scikit-learn
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install matplotlib
pip install seaborn

Note: The German dataset is not contained in this repository. The first time you try to train on the German dataset, you will get an error with instructions on how to download and store it. Follow those instructions so that training can proceed.

Datasets

This repository contains 7 different SCMs:

  • ColliderSCM
  • MGraphSCM
  • ChainSCM
  • TriangleSCM
  • LoanSCM
  • AdultSCM
  • GermanSCM

Additionally, we provide the implementation of the first five SCMs with three different types of structural equations: linear (LIN), non-linear (NLIN) and non-additive (NADD). You can find the implementation of all the datasets inside the folder datasets. To create all datasets at once run python _create_data_toy.py (this is optional since the datasets will be created as needed on the fly).
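To make the three equation types concrete, the sketch below writes one structural equation for a child node x2 with a single parent x1 in each style. The specific functions and coefficients are illustrative assumptions only, not the equations defined in the datasets folder.

```python
import math

# Hypothetical structural equations for a node x2 with parent x1 and noise u2.
# These forms are illustrative only; the repository defines its own equations.

def f_linear(x1, u2):
    # LIN: linear in the parent, noise enters additively.
    return 0.5 * x1 + u2

def f_non_linear(x1, u2):
    # NLIN: non-linear in the parent, noise still enters additively.
    return math.tanh(x1) + u2

def f_non_additive(x1, u2):
    # NADD: noise interacts with the parent instead of being added to it.
    return x1 * u2

x1, u2 = 2.0, 0.1
values = {
    "linear": f_linear(x1, u2),
    "non-linear": f_non_linear(x1, u2),
    "non-additive": f_non_additive(x1, u2),
}
```

In the first two cases the noise can be recovered by subtracting the parent term from the observed value; in the non-additive case this is not possible in general.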

How to create your custom Toy Datasets

We also provide a function to create custom ToySCM datasets. Here is an example of an SCM with 2 nodes:

from datasets.toy import create_toy_dataset
from utils.distributions import *
dataset = create_toy_dataset(root_dir='./my_custom_datasets',
                             name='2graph',
                             eq_type='linear',
                             nodes_to_intervene=['x1'],
                             structural_eq={'x1': lambda u1: u1,
                                            'x2': lambda u2, x1: u2 + x1},
                             noises_distr={'x1': Normal(0,1),
                                           'x2': Normal(0,1)},
                             adj_edges={'x1': ['x2'],
                                        'x2': []},
                             split='train',
                             num_samples=5000,
                             likelihood_names='d_d',
                             lambda_=0.05)
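To illustrate what this two-node SCM encodes, here is a minimal standard-library sketch that samples from the same structural equations (x1 := u1, x2 := u2 + x1) and answers one interventional and one counterfactual query by hand. It does not use create_toy_dataset or any other repository code; the function names are hypothetical.

```python
import random

# Structural equations of the two-node example: x1 := u1, x2 := u2 + x1.
def f_x1(u1):
    return u1

def f_x2(u2, x1):
    return u2 + x1

def sample_observational(rng):
    # Draw independent standard-normal noises, then evaluate in causal order.
    u1, u2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    x1 = f_x1(u1)
    x2 = f_x2(u2, x1)
    return {"u1": u1, "u2": u2, "x1": x1, "x2": x2}

def intervene_on_x1(rng, value):
    # do(x1 = value): replace the equation of x1 by a constant, keep x2's equation.
    u2 = rng.gauss(0.0, 1.0)
    return {"x1": value, "x2": f_x2(u2, value)}

def counterfactual_x2(factual, value):
    # Abduction: recover u2 from the factual sample (possible because the noise is additive).
    u2 = factual["x2"] - factual["x1"]
    # Action and prediction: set x1 to the new value and re-evaluate x2 with the same noise.
    return f_x2(u2, value)

rng = random.Random(0)
factual = sample_observational(rng)
x2_cf = counterfactual_x2(factual, value=2.0)
interventional_mean = sum(intervene_on_x1(rng, 2.0)["x2"] for _ in range(5000)) / 5000
```

Under do(x1 = 2) the mean of x2 should be close to 2, while the counterfactual keeps the noise of the observed sample and therefore returns a single value rather than a distribution.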

Training

To train a model you need to execute the script main.py. For that, you need to specify three configuration files:

  • dataset_file: Specifies the dataset and the parameters of the dataset. You can overwrite the dataset parameters with -d.
  • model_file: Specifies the model and the parameters of the model as well as the optimizer. You can overwrite the model parameters with -m and the optimizer parameters with -o.
  • trainer_file: Specifies the training parameters of the Trainer object from PyTorch Lightning.
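The override mechanism can be pictured as merging key=value strings into the parameters read from a configuration file. The sketch below is a hypothetical helper written for illustration; how main.py actually parses -d, -m and -o (value types, separators for several overrides) is not shown here and may differ.

```python
def apply_overrides(config, overrides):
    """Return a copy of `config` updated with 'key=value' override strings."""
    updated = dict(config)
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep or not key:
            raise ValueError(f"expected key=value, got {item!r}")
        # Values are kept as strings in this sketch; a real parser may cast them.
        updated[key] = value
    return updated

# Hypothetical dataset parameters, as they might be read from a dataset file.
dataset_params = {"name": "triangle", "equations_type": "linear"}
overridden = apply_overrides(dataset_params, ["equations_type=non-linear"])
```

The original dictionary is left unchanged, so the defaults from the file stay available next to the overridden copy.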

For plotting results use --plots 1. For more information, run python main.py --help.

Examples

To train our VACA algorithm on each of the synthetic graphs with linear structural equations (default value in dataset_ ):

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml

You can also select a different type of structural equations with the -d option:

  • for linear (LIN) equations, -d equations_type=linear,
  • for non-linear (NLIN) equations, -d equations_type=non-linear,
  • for non-additive (NADD) equations, -d equations_type=non-additive.

For example, to train on the triangle graph with non-linear structural equations:

python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml -d equations_type=non-linear

We can train our VACA algorithm on the German dataset:

python main.py --dataset_file _params/dataset_german.yaml --model_file _params/model_vaca.yaml

To run the CAREFL model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_carefl.yaml

To run the MultiCVAE model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_mcvae.yaml
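Since the commands above differ only in the dataset and model file names, they can be generated in a loop. The following sketch builds the argument lists for every dataset and model combination listed in this section; the helper name build_commands is hypothetical, and nothing is executed.

```python
import shlex

# Dataset and model names taken from the configuration files used above.
DATASETS = ["adult", "loan", "chain", "collider", "mgraph", "triangle"]
MODELS = ["vaca", "carefl", "mcvae"]

def build_commands(datasets=DATASETS, models=MODELS):
    """Return one argument list per (model, dataset) pair."""
    commands = []
    for model in models:
        for dataset in datasets:
            commands.append([
                "python", "main.py",
                "--dataset_file", f"_params/dataset_{dataset}.yaml",
                "--model_file", f"_params/model_{model}.yaml",
            ])
    return commands

commands = build_commands()
# Show the first command as a shell-ready string.
print(shlex.join(commands[0]))
```

To launch the runs, each list can be passed to subprocess.run(cmd, check=True) from the repository root.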

How to load a trained model?

To load a trained model:

  • set the training flag to -i 0,
  • select the configuration file of the trained model, i.e. hparams_full.yaml:

python main.py --yaml_file=PATH/hparams_full.yaml -i 0
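When several experiments have been run, it can help to collect their saved configuration files automatically. The sketch below assumes that each run stores its hparams_full.yaml somewhere below a common root folder; the exact folder layout is not specified here, and both helper names are hypothetical. The demo at the end uses a temporary directory with a placeholder file.

```python
import tempfile
from pathlib import Path

def find_trained_configs(root):
    """Return every hparams_full.yaml found below `root`, sorted by path."""
    return sorted(Path(root).rglob("hparams_full.yaml"))

def load_command(config_path):
    """Build the loading command from this section for one saved configuration."""
    return ["python", "main.py", f"--yaml_file={config_path}", "-i", "0"]

# Demo on a throwaway directory tree (the run folder name is made up).
with tempfile.TemporaryDirectory() as root:
    run_dir = Path(root) / "run_0"
    run_dir.mkdir()
    (run_dir / "hparams_full.yaml").write_text("placeholder: true\n")
    found = find_trained_configs(root)
    demo_command = load_command(found[0])
```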

Load a model and train/evaluate counterfactual fairness

Load your model and add the flag --eval_fair. For example:

python main.py --yaml_file=PATH/hparams_full.yaml -i 0 --eval_fair --show_results
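Counterfactual fairness is commonly assessed by comparing a predictor's output on a factual sample with its output on the counterfactual obtained by changing the sensitive attribute. The sketch below computes such a gap for toy predictors; it is a generic illustration with hypothetical names and made-up data, not the metric implemented behind --eval_fair.

```python
def counterfactual_gap(predict, factuals, counterfactuals):
    """Mean absolute prediction difference over paired factual/counterfactual samples."""
    if len(factuals) != len(counterfactuals):
        raise ValueError("factuals and counterfactuals must be paired")
    diffs = [abs(predict(x) - predict(x_cf))
             for x, x_cf in zip(factuals, counterfactuals)]
    return sum(diffs) / len(diffs)

# Toy data: each sample is (sensitive attribute, feature); values are made up.
factuals = [(0, 1.0), (1, 2.0)]
# Counterfactuals: sensitive attribute flipped, downstream feature changed accordingly.
counterfactuals = [(1, 1.5), (0, 1.5)]

def uses_sensitive(sample):
    # Predictor that depends on the sensitive attribute and its descendant.
    return 1.0 * sample[0] + 0.5 * sample[1]

def constant_predictor(sample):
    # Predictor that ignores its input, hence trivially counterfactually fair.
    return 0.0

gap_unfair = counterfactual_gap(uses_sensitive, factuals, counterfactuals)
gap_constant = counterfactual_gap(constant_predictor, factuals, counterfactuals)
```

A gap of zero means the predictions do not change under the counterfactual change of the sensitive attribute; larger values indicate stronger dependence.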

TensorBoard visualization

You can track different metrics during (and after) training using TensorBoard. For example, if the root folder of the experiments is exper_test, run the following command in a terminal:

tensorboard --logdir exper_test/   

This displays the logs of all experiments contained in that folder. Then open http://localhost:6006/ in a browser to visualize the results.

Owner
Pablo Sánchez-Martín
Ph.D. student at Max Planck Institute for Intelligent Systems