Graph Posterior Network: Bayesian Predictive Uncertainty for Node Classification (NeurIPS 2021)

Overview

Graph Posterior Network

This is the official code repository for the paper

Graph Posterior Network: Bayesian Predictive Uncertainty for Node Classification
Maximilian Stadler, Bertrand Charpentier, Simon Geisler, Daniel Zügner, Stephan Günnemann
Conference on Neural Information Processing Systems (NeurIPS) 2021.

[Paper | Video - coming soon]

Diagram

Installation

We recommend running this code with its dependencies in a conda environment. To begin, create a new conda environment with all the necessary dependencies, assuming that you are in the root directory of this project:

conda env create -f gpn_environment.yml python==3.8 --force

Since the code is packaged, you also have to set it up accordingly. Assuming that you are in the root directory of this project, run:

conda activate gpn
pip3 install -e .

Data

Since we rely on published datasets from the Torch-Geometric package, you don't have to download datasets manually. When you run experiments on supported datasets, they will be downloaded and placed in the corresponding data directories. You can run experiments on the following datasets:

  • CoraML
  • CiteSeer
  • PubMed
  • AmazonPhotos
  • AmazonComputers
  • CoauthorCS
  • CoauthorPhysics

Running Experiments

The experimental setup builds upon Sacred, with experiments configured in .yaml files. We provide configurations

  • for vanilla node classification
  • leave-out-class experiments
  • experiments with isolated node perturbations
  • experiments for feature shifts
  • experiments for edge shifts

with a default fraction of perturbed nodes of 10%. We provide them for the smaller datasets (i.e. all except ogbn-arxiv) for hidden dimensions H=10 and H=16.

The main experimental script is train_and_eval.py. Assuming that you are in the root directory of this project for all further commands, you can run experiments as described in the following sections.

Vanilla Node Classification

For the vanilla classification on the CoraML dataset with a hidden dimension of 16 or 10 respectively, run

python3 train_and_eval.py with configs/gpn/classification_gpn_16.yaml data.dataset=CoraML
python3 train_and_eval.py with configs/gpn/classification_gpn_10.yaml data.dataset=CoraML
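
To launch several such runs in sequence, the command lines can be generated programmatically. The following is a minimal sketch and not part of the repository: the helper name `build_command` and the choice of datasets in the sweep are illustrative assumptions, and only the config paths and the `data.dataset` key are taken from the commands above.

```python
import shlex
from itertools import product


def build_command(config_path, overrides):
    """Build the argv list for one run of train_and_eval.py (hypothetical helper)."""
    cmd = ["python3", "train_and_eval.py", "with", config_path]
    # Each override becomes one key=value token, e.g. data.dataset=CoraML.
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    return cmd


# Assumed sweep: three of the supported datasets, both hidden dimensions.
datasets = ["CoraML", "CiteSeer", "PubMed"]
hidden_dims = [16, 10]

commands = [
    build_command(f"configs/gpn/classification_gpn_{h}.yaml", {"data.dataset": d})
    for d, h in product(datasets, hidden_dims)
]

for cmd in commands:
    # Printed only; pass cmd to subprocess.run(cmd, check=True) to actually launch a run.
    print(shlex.join(cmd))
```

Keeping each command as a list of tokens avoids shell quoting issues if a value ever contains spaces.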

If you have GPU devices available on your system, experiments will run on device 0 by default. If no CUDA devices can be found, the code will fall back to running on CPUs only. Runs will produce assets by default. Also note that before running experiments for graphs under perturbations, you have to run the corresponding vanilla classification experiment first.

Options for Feature Shifts

We consider random features drawn from a unit Gaussian distribution (normal) and from a Bernoulli distribution (bernoulli_0.5). When using the configuration ood_features, you can change this setting (key ood_perturbation_type) on the command line, together with the fraction of perturbed nodes (key ood_budget_per_graph), or in the corresponding configuration files, for example:

python3 train_and_eval.py with configs/gpn/ood_features_gpn_16.yaml data.dataset=CoraML data.ood_perturbation_type=normal data.ood_budget_per_graph=0.025
python3 train_and_eval.py with configs/gpn/ood_features_gpn_16.yaml data.dataset=CoraML data.ood_perturbation_type=bernoulli_0.5 data.ood_budget_per_graph=0.025
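
To make the two options concrete, here is a minimal sketch of what such a feature shift does conceptually: the feature vectors of a random fraction of nodes are replaced with noise. It is not the repository's implementation. The function name `perturb_features`, the rounding of the budget to a node count, and the toy feature matrix are assumptions made for illustration; only the names `normal`, `bernoulli_0.5` and the notion of a per-graph budget come from the text above.

```python
import random


def perturb_features(features, budget, kind, seed=0):
    """Replace the features of a random fraction of nodes with noise (illustrative only)."""
    rng = random.Random(seed)
    num_nodes = len(features)
    # Assumption: the budget is a fraction of nodes, rounded to the nearest integer.
    num_perturbed = round(budget * num_nodes)
    perturbed_nodes = sorted(rng.sample(range(num_nodes), num_perturbed))

    out = [list(row) for row in features]  # copy so the input stays untouched
    for node in perturbed_nodes:
        dim = len(out[node])
        if kind == "normal":
            # Unit Gaussian: mean 0, standard deviation 1 per feature.
            out[node] = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        elif kind == "bernoulli_0.5":
            # Each feature is 1 with probability 0.5, else 0.
            out[node] = [1.0 if rng.random() < 0.5 else 0.0 for _ in range(dim)]
        else:
            raise ValueError(f"unknown perturbation type: {kind}")
    return out, perturbed_nodes


# Toy graph: 40 nodes with 8 features each, all zero.
features = [[0.0] * 8 for _ in range(40)]
perturbed, ood_nodes = perturb_features(features, budget=0.025, kind="bernoulli_0.5")
print(len(ood_nodes), "of", len(features), "nodes perturbed")
```

The returned node indices mark which nodes should be treated as out-of-distribution when evaluating uncertainty estimates.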

For experiments that consider perturbations in an isolated fashion, the same applies but without the fraction of perturbed nodes, e.g.

python3 train_and_eval.py with configs/gpn/ood_isolated_gpn_16.yaml data.dataset=CoraML data.ood_perturbation_type=normal
python3 train_and_eval.py with configs/gpn/ood_isolated_gpn_16.yaml data.dataset=CoraML data.ood_perturbation_type=bernoulli_0.5

Options for Edge Shifts

We consider random edge perturbations and the global, untargeted DICE attack. The attack is selected with the key ood_type, which can be set to either random_attack_dice or random_edge_perturbations. As above, these settings can be changed on the command line or in the corresponding configuration files. While the key ood_budget_per_graph refers to the fraction of perturbed nodes in the section above, it describes the fraction of perturbed edges in this case.

python3 train_and_eval.py with configs/gpn/ood_features_gpn_16.yaml data.dataset=CoraML data.ood_type=random_attack_dice data.ood_budget_per_graph=0.025
python3 train_and_eval.py with configs/gpn/ood_features_gpn_16.yaml data.dataset=CoraML data.ood_type=random_edge_perturbations data.ood_budget_per_graph=0.025
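
The following is a minimal sketch of what a DICE-style perturbation does to a graph, assuming the usual reading of DICE as removing edges between nodes of the same class and adding edges between nodes of different classes. It is not the repository's implementation: the function name `dice_perturb`, the even split of the budget between removals and insertions, and the toy ring graph are all assumptions made for illustration.

```python
import random


def dice_perturb(edges, labels, budget, seed=0):
    """DICE-style perturbation of an undirected edge list (illustrative only)."""
    rng = random.Random(seed)
    edge_set = {(min(u, v), max(u, v)) for u, v in edges}
    # Assumption: the budget is a fraction of edges, split evenly between
    # removals and insertions.
    num_changes = round(budget * len(edge_set))
    num_remove = num_changes // 2
    num_add = num_changes - num_remove

    # Disconnect internally: drop edges whose endpoints share a label.
    internal = sorted(e for e in edge_set if labels[e[0]] == labels[e[1]])
    removed = rng.sample(internal, min(num_remove, len(internal)))
    edge_set.difference_update(removed)

    # Connect externally: add missing edges between differently labelled nodes.
    n = len(labels)
    external = [
        (u, v)
        for u in range(n)
        for v in range(u + 1, n)
        if labels[u] != labels[v] and (u, v) not in edge_set
    ]
    added = rng.sample(external, min(num_add, len(external)))
    edge_set.update(added)
    return sorted(edge_set), sorted(removed), sorted(added)


# Toy graph: a ring of 20 nodes, the first 10 in class 0 and the rest in class 1.
labels = [i // 10 for i in range(20)]
edges = [(i, (i + 1) % 20) for i in range(20)]
new_edges, removed, added = dice_perturb(edges, labels, budget=0.2)
print("removed:", removed)
print("added:", added)
```

A purely random edge perturbation would skip the label checks and sample removals and insertions uniformly; enumerating all candidate pairs as done here is only practical for small graphs.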

Further Options

With the settings above, you can reproduce our experimental results. If you want to change architectural settings, change the corresponding keys in the configuration files; most of them are self-explanatory.
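
Command-line overrides such as data.dataset=CoraML address nested configuration entries through dotted keys. The sketch below mimics that mapping on a plain dictionary to show which entry an override touches. It does not use Sacred itself, the helper name `apply_override` is hypothetical, and the configuration excerpt contains only keys mentioned in this README, with values chosen for illustration.

```python
import copy


def apply_override(config, dotted_key, value):
    """Return a copy of config with the nested entry at dotted_key set to value."""
    updated = copy.deepcopy(config)
    node = updated
    *parents, leaf = dotted_key.split(".")
    for name in parents:
        # Walk down one level per key segment, creating sections if missing.
        node = node.setdefault(name, {})
    node[leaf] = value
    return updated


# Hypothetical excerpt of a loaded configuration file.
base_config = {"data": {"dataset": "CoraML", "ood_budget_per_graph": 0.1}}

config = apply_override(base_config, "data.dataset", "CiteSeer")
config = apply_override(config, "data.ood_budget_per_graph", 0.025)
print(config)
```

Editing the same keys directly in the .yaml file has the same effect and is preferable for settings that should persist across runs.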

Structure

If you want to take a detailed look at our code, here is a brief overview of its structure.

  • configs: directory for model configurations
  • data: directory for datasets
  • gpn: source code
    • gpn.data: code related to loading datasets and creating ID and OOD datasets
    • gpn.distributions: code related to custom distributions similar to torch.distributions
    • gpn.experiments: main routines for running experiments, i.e. loading configs, setting up datasets and models, training and evaluation
    • gpn.layers: custom layers
    • gpn.models: implementation of reference models and Graph Posterior Network (+ablated models)
    • gpn.nn: training related utilities like losses, metrics, or training engines
    • gpn.utils: general utility code
  • saved_experiments: directory for saved models
  • train_and_eval.py: main script for training & evaluation
  • gpn_qualitative_evaluation.ipynb: jupyter notebook which evaluates the results from Graph Posterior Network in a qualitative fashion

Note that we provide implementations of most of the reference models we used. Our main Graph Posterior Network model can be found in gpn.models.gpn_base.py. Ablated models can be found in a similar fashion, i.e. PostNet in gpn.models.gpn_postnet.py, PostNet+diffusion in gpn.models.gpn_postnet_diff.py and the model diffusing log-beta scores in gpn.models.gpn_log_beta.py.

We provide all basic configurations for reference models in configs/reference. Note that some models depend on others:

  • classification_gcn_dropout.yaml and classification_gcn_energy.yaml require training the underlying GCN first by running classification_gcn.yaml.
  • classification_gcn_ensemble.yaml requires training 10 GCNs first, with init_no in 1...10.
  • classification_sgcn.yaml (GKDE-GCN) requires training the teacher GCN first by running classification_gcn.yaml and computing the kernel values by running classification_gdk.yaml.
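
The dependencies between reference configurations stated above can be resolved into a run order mechanically. The sketch below is a minimal illustration and not part of the repository: the `PREREQUISITES` table encodes only the dependencies named in this README, the ensemble is left out because it needs ten separate runs with different init_no values, and the helper name `run_order` is hypothetical.

```python
# Prerequisite configurations as stated in this README (ensemble omitted).
PREREQUISITES = {
    "classification_gcn.yaml": [],
    "classification_gdk.yaml": [],
    "classification_gcn_dropout.yaml": ["classification_gcn.yaml"],
    "classification_gcn_energy.yaml": ["classification_gcn.yaml"],
    "classification_sgcn.yaml": ["classification_gcn.yaml", "classification_gdk.yaml"],
}


def run_order(target, prerequisites):
    """Return the configurations to run, prerequisites first, ending with target."""
    order = []

    def visit(name):
        if name in order:
            return  # already scheduled
        for dep in prerequisites[name]:
            visit(dep)
        order.append(name)

    visit(target)
    return order


for config_name in run_order("classification_sgcn.yaml", PREREQUISITES):
    print("python3 train_and_eval.py with configs/reference/" + config_name)
```

Each printed command still needs the dataset override, for example data.dataset=CoraML, appended before it is run.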

Cite

Please cite our paper if you use the model or this code in your own work.

@incollection{graph-postnet,
title={Graph Posterior Network: Bayesian Predictive Uncertainty for Node Classification},
author={Stadler, Maximilian and Charpentier, Bertrand and Geisler, Simon and Z{\"u}gner, Daniel and G{\"u}nnemann, Stephan},
booktitle = {Advances in Neural Information Processing Systems},
volume = {34},
publisher = {Curran Associates, Inc.},
year = {2021}
}