Attention over nodes in Graph Neural Networks using PyTorch (NeurIPS 2019)

Overview

Intro

This repository contains code to generate data and reproduce experiments from our NeurIPS 2019 paper:

Boris Knyazev, Graham W. Taylor, Mohamed R. Amer. Understanding Attention and Generalization in Graph Neural Networks.

See slides here.

An earlier short version of our paper was presented as a contributed talk at the ICLR 2019 Workshop on Representation Learning on Graphs and Manifolds.

Update:

In the code for MNIST, the dist variable should have been squared to make the kernel a true Gaussian. All figures and results were generated without squaring it. I don't expect this to matter much for the results, but if you do square it, sigma should be adjusted accordingly.
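To make the difference concrete, here is a minimal NumPy sketch of the two kernels. The function name `spatial_adjacency` and its `squared` switch are hypothetical (not part of this repo). `squared=False` mirrors the released MNIST code, `exp(-d / sigma^2)`, and `squared=True` gives the Gaussian `exp(-d^2 / sigma^2)`.

```python
import numpy as np


def spatial_adjacency(coord, sigma, squared=False):
    """Dense adjacency matrix from node coordinates (hypothetical helper).

    squared=False: exp(-d / sigma^2), as in the released code and figures.
    squared=True:  exp(-d^2 / sigma^2), a true Gaussian kernel.
    """
    coord = np.asarray(coord, dtype=np.float64)
    # Pairwise Euclidean distances via broadcasting: (N, 1, 2) - (1, N, 2)
    diff = coord[:, None, :] - coord[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    if squared:
        dist = dist ** 2
    A = np.exp(-dist / sigma ** 2)
    np.fill_diagonal(A, 0.0)  # remove self-loops
    return A


if __name__ == "__main__":
    coords = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.3]])  # normalized (row, col)
    print(spatial_adjacency(coords, sigma=0.1 * np.pi))
    print(spatial_adjacency(coords, sigma=0.1 * np.pi, squared=True))
```

Since normalized coordinates give d < 1 for most pairs, squaring shrinks distances and makes edges stronger, so sigma would need to be smaller. One possible heuristic (not the authors' choice) is to make both kernels agree at a reference distance d0, which gives sigma_squared = sigma * sqrt(d0).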

[Figure: animated examples of attention on MNIST and TRIANGLES]

For MNIST from top to bottom rows:

  • input test images with additive Gaussian noise with standard deviation in the range from 0 to 1.4 with step 0.2
  • attention coefficients (alpha) predicted by the unsupervised model
  • attention coefficients (alpha) predicted by the supervised model
  • attention coefficients (alpha) predicted by our weakly-supervised model
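The noisy test inputs in the first MNIST row can be approximated with the sketch below. It assumes the images are floats in [0, 1] (as in the evaluation example further down) and that no clipping is applied after adding the noise. Both are assumptions, and the repo's own evaluation code may differ. `NOISE_STDS` and `add_gaussian_noise` are hypothetical names.

```python
import numpy as np

# Standard deviations from the caption: 0 to 1.4 in steps of 0.2 (8 levels).
NOISE_STDS = np.linspace(0.0, 1.4, 8)


def add_gaussian_noise(img, std, seed=None):
    """Return a copy of `img` with additive zero-mean Gaussian noise.

    Assumption: intensities are in [0, 1] and the result is NOT clipped.
    """
    rng = np.random.default_rng(seed)
    noise = rng.normal(loc=0.0, scale=std, size=img.shape)
    return (img + noise).astype(np.float32)


if __name__ == "__main__":
    img = np.zeros((28, 28), dtype=np.float32)  # placeholder 28x28 image
    for std in NOISE_STDS:
        noisy = add_gaussian_noise(img, std, seed=0)
        print(f"std={std:.1f}  empirical std={noisy.std():.3f}")
```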

For TRIANGLES from top to bottom rows:

  • on the left: input test graph (with 4-100 nodes) with ground truth attention coefficients, on the right: graph obtained by ground truth node pooling
  • on the left: input test graph (with 4-100 nodes) with unsupervised attention coefficients, on the right: graph obtained by unsupervised node pooling
  • on the left: input test graph (with 4-100 nodes) with supervised attention coefficients, on the right: graph obtained by supervised node pooling
  • on the left: input test graph (with 4-100 nodes) with weakly-supervised attention coefficients, on the right: graph obtained by weakly-supervised node pooling
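A pooled graph such as those shown on the right can be obtained from attention coefficients in several ways. The sketch below shows one generic variant: keep the nodes whose attention exceeds a threshold, scale their features by alpha, and take the induced subgraph. It is an illustration under those assumptions, not this repository's pooling layer. `threshold_pool` is a hypothetical name.

```python
import numpy as np


def threshold_pool(x, adj, alpha, threshold):
    """Generic thresholded attention pooling (illustrative sketch).

    x:     (N, F) node features
    adj:   (N, N) adjacency matrix
    alpha: (N,) attention coefficients
    Returns pooled features, pooled adjacency, and indices of kept nodes.
    """
    alpha = np.asarray(alpha, dtype=np.float64).reshape(-1)
    keep = np.flatnonzero(alpha > threshold)   # nodes that survive pooling
    x_pooled = x[keep] * alpha[keep][:, None]  # assumption: features scaled by attention
    adj_pooled = adj[np.ix_(keep, keep)]       # induced subgraph on kept nodes
    return x_pooled, adj_pooled, keep


if __name__ == "__main__":
    x = np.arange(8, dtype=np.float64).reshape(4, 2)
    adj = np.array([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 0], [1, 1, 0, 0]])
    print(threshold_pool(x, adj, [0.1, 0.5, 0.0, 0.4], threshold=0.2))
```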

Note that during training, our MNIST models never saw noisy images, and our TRIANGLES models never saw graphs with more than N=25 nodes.

Examples using PyTorch Geometric

The COLORS and TRIANGLES datasets are now also available in the TU format, so they can be loaded with a generic TU data reader. See the PyTorch Geometric examples for COLORS and TRIANGLES.

Example of evaluating a pretrained model on MNIST

For more examples, see MNIST_eval_models and TRIANGLES_eval_models.

# Download model checkpoint or 'git clone' this repo
import os
import urllib.request
# Let's use the model with supervised attention (other models can be found in the table below)
model_name = 'checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar'
model_url = 'https://github.com/bknyaz/graph_attention_pool/raw/master/checkpoints/%s' % model_name
model_path = os.path.join('checkpoints', model_name)
os.makedirs('checkpoints', exist_ok=True)  # urlretrieve does not create missing directories
urllib.request.urlretrieve(model_url, model_path)
# Load the model
import torch
from chebygin import ChebyGIN

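# Load on CPU. The checkpoint also stores the training args, so on PyTorch >= 2.6
# (where weights_only=True is the default) you may need weights_only=False for this trusted file
state = torch.load(model_path, map_location='cpu')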
args = state['args']
model = ChebyGIN(in_features=5, out_features=10, filters=args.filters, K=args.filter_scale,
                 n_hidden=args.n_hidden, aggregation=args.aggregation, dropout=args.dropout,
                 readout=args.readout, pool=args.pool, pool_arch=args.pool_arch)
model.load_state_dict(state['state_dict'])
model = model.eval()
# Load image using standard PyTorch Dataset
from torchvision import datasets
data = datasets.MNIST('./data', train=False, download=True)
images = (data.test_data.numpy() / 255.)
import numpy as np
img = images[0].astype(np.float32)  # 28x28 MNIST image
# Extract superpixels and create node features
import scipy.ndimage
from skimage.segmentation import slic
from scipy.spatial.distance import cdist

# The number (n_segments) of superpixels returned by SLIC is usually smaller than requested, so we request more
# On scikit-image >= 0.19, replace multichannel=False with channel_axis=None
superpixels = slic(img, n_segments=95, compactness=0.25, multichannel=False)
sp_indices = np.unique(superpixels)
n_sp = len(sp_indices)  # should be 74 with these parameters of slic

sp_intensity = np.zeros((n_sp, 1), np.float32)
sp_coord = np.zeros((n_sp, 2), np.float32)  # row, col
# Index by position, not by label: SLIC labels start at 0 or 1 depending on the scikit-image version
for i, seg in enumerate(sp_indices):
    mask = superpixels == seg
    sp_intensity[i] = np.mean(img[mask])
    sp_coord[i] = np.array(scipy.ndimage.center_of_mass(mask))

# The model is invariant to the order of nodes in a graph
# We can shuffle nodes and obtain exactly the same results
ind = np.random.permutation(n_sp)
sp_coord = sp_coord[ind]
sp_intensity = sp_intensity[ind]
# Create edges between nodes in the form of adjacency matrix
sp_coord = sp_coord / images.shape[1]
dist = cdist(sp_coord, sp_coord)  # distance between all pairs of nodes
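sigma = 0.1 * np.pi  # width of a Gaussian
# dist is intentionally not squared here, to match how the pretrained models were trained (see Update above)
A = np.exp(- dist / sigma ** 2)  # transform distance to spatial closeness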
A[np.diag_indices_from(A)] = 0  # remove self-loops
A = torch.from_numpy(A).float().unsqueeze(0)
# Prepare an input to the model and process it
N_nodes = sp_intensity.shape[0]
mask = torch.ones(1, N_nodes, dtype=torch.uint8)

# mean and std computed for superpixel features in the training set
mn = torch.tensor([0.11225057, 0.11225057, 0.11225057, 0.44206527, 0.43950436]).view(1, 1, -1)
sd = torch.tensor([0.2721889,  0.2721889,  0.2721889,  0.2987583,  0.30080357]).view(1, 1, -1)

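# 5 features per node: intensity repeated 3 times (via 'edge' padding), then row and col; normalized with mn, sd
node_features = (torch.from_numpy(np.pad(np.concatenate((sp_intensity, sp_coord), axis=1),
                                         ((0, 0), (2, 0)), 'edge')).unsqueeze(0) - mn) / sd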

y, other_outputs = model([node_features, A, mask, None, {'N_nodes': torch.zeros(1, 1) + N_nodes}])
alpha = other_outputs['alpha'][0].data
  • y contains 10 unnormalized class scores. To get the predicted label, use torch.argmax(y).

  • alpha contains the attention coefficient of each node, in the (shuffled) order in which the nodes were fed to the model.
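To visualize alpha on the image, each coefficient has to be mapped back to its superpixel, taking the random permutation `ind` into account. The sketch below assumes the variable names from the example above (`superpixels`, `sp_indices`, `ind`) and that node i before shuffling is superpixel `sp_indices[i]`. `attention_to_image` is a hypothetical helper, not part of the repo.

```python
import numpy as np


def attention_to_image(superpixels, sp_indices, ind, alpha):
    """Paint per-node attention coefficients back onto the pixel grid.

    superpixels: (H, W) SLIC label image
    sp_indices:  np.unique(superpixels); node i before shuffling is label sp_indices[i]
    ind:         permutation used to shuffle nodes (shuffled node k = original node ind[k])
    alpha:       (N,) or (N, 1) attention coefficients in the shuffled order
    """
    alpha = np.asarray(alpha, dtype=np.float32).reshape(-1)
    heatmap = np.zeros(superpixels.shape, dtype=np.float32)
    for k, orig_node in enumerate(ind):
        heatmap[superpixels == sp_indices[orig_node]] = alpha[k]
    return heatmap


if __name__ == "__main__":
    sp = np.array([[0, 0, 1], [2, 2, 1]])
    print(attention_to_image(sp, np.unique(sp), np.array([2, 0, 1]), [0.5, 0.1, 0.4]))
```

With the model outputs above, this would be called as `attention_to_image(superpixels, sp_indices, ind, alpha.cpu().numpy())`, and `torch.argmax(y).item()` gives the predicted digit as a Python int.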

Tasks & Datasets

  1. We design two synthetic graph tasks, COLORS and TRIANGLES, in which we predict the number of green nodes and the number of triangles respectively.

  2. We also experiment with the MNIST image classification dataset, which we preprocess by extracting superpixels, a more natural way to represent an image as a graph. We denote this dataset MNIST-75sp.

  3. We validate our weakly-supervised approach on three common graph classification benchmarks: COLLAB, PROTEINS and D&D.
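The TRIANGLES target (the number of triangles in a graph) can be computed from the adjacency matrix with a standard identity: trace(A^3) counts each triangle 6 times (3 starting nodes × 2 directions), and diag(A^3) / 2 gives the number of triangles each node belongs to. This is a textbook sketch for an undirected graph without self-loops. The repo's data generator may count triangles differently.

```python
import numpy as np


def count_triangles(adj):
    """Total number of triangles in an undirected simple graph (0/1 adjacency)."""
    A = np.asarray(adj, dtype=np.int64)
    if not (A == A.T).all() or A.diagonal().any():
        raise ValueError("expected a symmetric adjacency matrix without self-loops")
    # Each triangle yields 6 closed walks of length 3.
    return int(np.trace(A @ A @ A)) // 6


def triangles_per_node(adj):
    """Number of triangles each node participates in: diag(A^3) / 2."""
    A = np.asarray(adj, dtype=np.int64)
    return np.diagonal(A @ A @ A) // 2


if __name__ == "__main__":
    k4 = np.ones((4, 4), dtype=np.int64) - np.eye(4, dtype=np.int64)  # complete graph K4
    print(count_triangles(k4), triangles_per_node(k4))
```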

For COLORS, TRIANGLES and MNIST we know ground truth attention for nodes, which allows us to study graph neural networks with attention in depth.

Data generation

To generate all data using a single command: ./scripts/prepare_data.sh.

All generated or downloaded data will be stored in the local ./data directory. Preparing all data can take about 1 hour (see my log), and the data take about 2 GB of disk space.

Alternatively, you can generate data for each task as described below.

In case of any issues with running these scripts, data can be downloaded from here.

COLORS

To generate training, validation and test data for the COLORS dataset with different node feature dimensionalities:

for dim in 3 8 16 32; do python generate_data.py --dim $dim; done

MNIST-75sp

To generate training and test data for our MNIST-75sp dataset using 4 CPU threads:

for split in train test; do python extract_superpixels.py -s $split -t 4; done

Data visualization

Once datasets are generated or downloaded, you can use the following IPython notebooks to load and visualize data:

COLORS and TRIANGLES, MNIST and COLLAB, PROTEINS and D&D.

Pretrained ChebyGIN models

Generalization results (test accuracy, %) on the test sets for three tasks. Other results are available in the paper.

Click on the result to download a trained model in the PyTorch format.

| Model | COLORS-Test-LargeC | TRIANGLES-Test-Large | MNIST-75sp-Test-Noisy |
|---|---|---|---|
| Script to train models | colors.sh | triangles.sh | mnist_75sp.sh |
| Global pooling | 15 ± 7 | 30 ± 1 | 80 ± 12 |
| Unsupervised attention | 11 ± 6 | 26 ± 2 | 80 ± 23 |
| Supervised attention | 75 ± 17 | 48 ± 1 | 92.3 ± 0.4 |
| Weakly-supervised attention | 73 ± 14 | 30 ± 1 | 88.8 ± 4 |

The scripts to train the models must be run from the main directory, e.g.: ./scripts/mnist_75sp.sh

Examples of evaluating our trained models can be found in notebooks: MNIST_eval_models and TRIANGLES_eval_models.

Other examples of training models

To tune hyperparameters on the validation set for COLORS, TRIANGLES and MNIST, use the --validation flag.

For COLLAB, PROTEINS and D&D, hyperparameter tuning is built into the training script; enable it with the --ax flag.

Example of running weakly-supervised experiments on PROTEINS (10 data seeds × 10 model seeds) with cross-validation of hyperparameters, including the initialization parameters (distribution and scale) of the attention model (the --tune_init flag):

mkdir -p logs  # tee fails if the logs directory does not exist
for i in $(seq 1 1 10); do
  dataseed=$(( ( RANDOM % 10000 ) + 1 ))
  for j in $(seq 1 1 10); do
    seed=$(( ( RANDOM % 10000 ) + 1 ))
    python main.py --seed $seed -D TU --n_nodes 25 --epochs 50 --lr_decay_step 25,35,45 \
      --test_batch_size 100 -f 64,64,64 -K 1 --readout max --dropout 0.1 \
      --pool attn_sup_threshold_skip_skip_0 --pool_arch fc_prev --results None \
      --data_dir ./data/PROTEINS --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 \
      --ax --ax_trials 30 --scale None --tune_init \
      | tee logs/proteins_wsup_"$dataseed"_"$seed".log
  done
done

The same on COLLAB, without initialization tuning:

mkdir -p logs  # tee fails if the logs directory does not exist
for i in $(seq 1 1 10); do
  dataseed=$(( ( RANDOM % 10000 ) + 1 ))
  for j in $(seq 1 1 10); do
    seed=$(( ( RANDOM % 10000 ) + 1 ))
    python main.py --seed $seed -D TU --n_nodes 35 --epochs 50 --lr_decay_step 25,35,45 \
      --test_batch_size 32 -f 64,64,64 -K 3 --readout max --dropout 0.1 \
      --pool attn_sup_threshold_skip_skip_skip_0 --pool_arch fc_prev --results None \
      --data_dir ./data/COLLAB --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 \
      --ax --ax_trials 30 --scale None \
      | tee logs/collab_wsup_"$dataseed"_"$seed".log
  done
done

Results may improve with --pool_arch gnn_prev, but we did not explore this option in depth.

Requirements

Python packages required (can be installed via pip or conda):

  • python >= 3.6.1
  • PyTorch >= 0.4.1
  • Ax for hyper-parameter tuning on COLLAB, PROTEINS and D&D
  • networkx
  • OpenCV
  • SciPy
  • scikit-image
  • scikit-learn

Reference

Please cite our paper if you use our data or code:

@inproceedings{knyazev2019understanding,
  title={Understanding attention and generalization in graph neural networks},
  author={Knyazev, Boris and Taylor, Graham W and Amer, Mohamed},
  booktitle={Advances in Neural Information Processing Systems},
  pages={4202--4212},
  year={2019},
  pdf={http://arxiv.org/abs/1905.02850}
}