Code for Understanding Pooling in Graph Neural Networks

Related tags

Deep LearningSRC
Overview

Select, Reduce, Connect

This repository contains the code used for the experiments of:

"Understanding Pooling in Graph Neural Networks"

Setup

Install TensorFlow and other dependencies:

pip install -r requirements.txt

Running experiments

Experiments are found in the following folders:

  • autoencoder/
  • spectral_similarity/
  • graph_classification/

Each folder has a bash script called run_all.sh that will reproduce the results reported in the paper.

To generate the plots and tables that we included in the paper, you can use the plots.py, plots_datasets.py, or tables.py found in the folders.

To run experiments for an individual pooling operator, you can use the run_[OPERATOR NAME].py scripts in each folder.

The pooling operators that we used for the experiments are found in layers/ (trainable) and modules/ (non-trainable). The GNN architectures used in the experiments are found in models/.

The SRCPool class

The core of this repository is the SRCPool class that implements a general interface to create SRC pooling layers with the Keras API.

Our implementation of MinCutPool, DiffPool, LaPool, Top-K, and SAGPool using the SRCPool class can be found in src/layers.

In general, SRC layers compute:

Where is a node equivariant selection function that computes the supernode assignments , is a permutation-invariant function to reduce the supernodes into the new node attributes, and is a permutation-invariant connection function that computes the links between the pooled nodes.

By extending this class, it is possible to create any pooling layer in the SRC framework.

Input

  • X: Tensor of shape ([batch], N, F) representing node features;
  • A: Tensor or SparseTensor of shape ([batch], N, N) representing the adjacency matrix;
  • I: (optional) Tensor of integers with shape (N, ) representing the batch index;

Output

  • X_pool: Tensor of shape ([batch], K, F), representing the node features of the output. K is the number of output nodes and depends on the specific pooling strategy;
  • A_pool: Tensor or SparseTensor of shape ([batch], K, K) representing the adjacency matrix of the output;
  • I_pool: (only if I was given as input) Tensor of integers with shape (K, ) representing the batch index of the output;
  • S_pool: (if return_sel=True) Tensor or SparseTensor representing the supernode assignments;

API

  • pool(X, A, I, **kwargs): pools the graph and returns the reduced node features and adjacency matrix. If the batch index I is not None, a reduced version of I will be returned as well. Any given kwargs will be passed as keyword arguments to select(), reduce() and connect() if any matching key is found. The mandatory arguments of pool() (X, A, and I) must be computed in call() by calling self.get_inputs(inputs).
  • select(X, A, I, **kwargs): computes supernode assignments mapping the nodes of the input graph to the nodes of the output.
  • reduce(X, S, **kwargs): reduces the supernodes to form the nodes of the pooled graph.
  • connect(A, S, **kwargs): connects the reduced supernodes.
  • reduce_index(I, S, **kwargs): helper function to reduce the batch index (only called if I is given as input).

When overriding any function of the API, it is possible to access the true number of nodes of the input (N) as a Tensor in the instance variable self.N (this is populated by self.get_inputs() at the beginning of call()).

Arguments:

  • return_sel: if True, the Tensor used to represent supernode assignments will be returned with X_pool, A_pool, and I_pool;
Owner
Daniele Grattarola
PhD student @ Università della Svizzera italiana
Daniele Grattarola
[ICCV 2021 Oral] PoinTr: Diverse Point Cloud Completion with Geometry-Aware Transformers

PoinTr: Diverse Point Cloud Completion with Geometry-Aware Transformers Created by Xumin Yu*, Yongming Rao*, Ziyi Wang, Zuyan Liu, Jiwen Lu, Jie Zhou

Xumin Yu 317 Dec 26, 2022
Fast mesh denoising with data driven normal filtering using deep variational autoencoders

Fast mesh denoising with data driven normal filtering using deep variational autoencoders This is an implementation for the paper entitled "Fast mesh

9 Dec 02, 2022
Code of 3D Shape Variational Autoencoder Latent Disentanglement via Mini-Batch Feature Swapping for Bodies and Faces

3D Shape Variational Autoencoder Latent Disentanglement via Mini-Batch Feature Swapping for Bodies and Faces Installation After cloning the repo open

37 Dec 03, 2022
Scenarios, tutorials and demos for Autonomous Driving

The Autonomous Driving Cookbook (Preview) NOTE: This project is developed and being maintained by Project Road Runner at Microsoft Garage. This is cur

Microsoft 2.1k Jan 02, 2023
Generative Adversarial Text-to-Image Synthesis

###Generative Adversarial Text-to-Image Synthesis Scott Reed, Zeynep Akata, Xinchen Yan, Lajanugen Logeswaran, Bernt Schiele, Honglak Lee This is the

Scott Ellison Reed 883 Dec 31, 2022
A privacy-focused, intelligent security camera system.

Self-Hosted Home Security Camera System A privacy-focused, intelligent security camera system. Features: Multi-camera support w/ minimal configuration

Scott Barnes 175 Jan 01, 2023
Neural Surface Maps

Neural Surface Maps Official implementation of Neural Surface Maps - Luca Morreale, Noam Aigerman, Vladimir Kim, Niloy J. Mitra [Paper] [Project Page]

Luca Morreale 49 Dec 13, 2022
Semi-supervised Implicit Scene Completion from Sparse LiDAR

Semi-supervised Implicit Scene Completion from Sparse LiDAR Paper Created by Pengfei Li, Yongliang Shi, Tianyu Liu, Hao Zhao, Guyue Zhou and YA-QIN ZH

114 Nov 30, 2022
Sentinel-1 vessel detection model used in the xView3 challenge

sar_vessel_detect Code for the AI2 Skylight team's submission in the xView3 competition (https://iuu.xview.us) for vessel detection in Sentinel-1 SAR

AI2 6 Sep 10, 2022
OpenMMLab Computer Vision Foundation

English | 简体中文 Introduction MMCV is a foundational library for computer vision research and supports many research projects as below: MMCV: OpenMMLab

OpenMMLab 4.6k Jan 09, 2023
[ICLR 2021] "CPT: Efficient Deep Neural Network Training via Cyclic Precision" by Yonggan Fu, Han Guo, Meng Li, Xin Yang, Yining Ding, Vikas Chandra, Yingyan Lin

CPT: Efficient Deep Neural Network Training via Cyclic Precision Yonggan Fu, Han Guo, Meng Li, Xin Yang, Yining Ding, Vikas Chandra, Yingyan Lin Accep

26 Oct 25, 2022
Space Time Recurrent Memory Network - Pytorch

Space Time Recurrent Memory Network - Pytorch (wip) Implementation of Space Time Recurrent Memory Network, recurrent network competitive with attentio

Phil Wang 50 Nov 07, 2021
S2-BNN: Bridging the Gap Between Self-Supervised Real and 1-bit Neural Networks via Guided Distribution Calibration (CVPR 2021)

S2-BNN (Self-supervised Binary Neural Networks Using Distillation Loss) This is the official pytorch implementation of our paper: "S2-BNN: Bridging th

Zhiqiang Shen 52 Dec 24, 2022
Codes and Data Processing Files for our paper.

Code Scripts and Processing Files for EEG Sleep Staging Paper 1. Folder Tree ./src_preprocess (data preprocessing files for SHHS and Sleep EDF) sleepE

Chaoqi Yang 18 Dec 12, 2022
This is the 3D Implementation of 《Inconsistency-aware Uncertainty Estimation for Semi-supervised Medical Image Segmentation》

CoraNet This is the 3D Implementation of 《Inconsistency-aware Uncertainty Estimation for Semi-supervised Medical Image Segmentation》 Environment pytor

25 Nov 08, 2022
Quickly and easily create / train a custom DeepDream model

Dream-Creator This project aims to simplify the process of creating a custom DeepDream model by using pretrained GoogleNet models and custom image dat

55 Dec 27, 2022
CoReD: Generalizing Fake Media Detection with Continual Representation using Distillation (ACMMM'21 Oral Paper)

CoReD: Generalizing Fake Media Detection with Continual Representation using Distillation (ACMMM'21 Oral Paper) (Accepted for oral presentation at ACM

Minha Kim 1 Nov 12, 2021
Code and models for "Pano3D: A Holistic Benchmark and a Solid Baseline for 360 Depth Estimation", OmniCV Workshop @ CVPR21.

Pano3D A Holistic Benchmark and a Solid Baseline for 360o Depth Estimation Pano3D is a new benchmark for depth estimation from spherical panoramas. We

Visual Computing Lab, Information Technologies Institute, Centre for Reseach and Technology Hellas 50 Dec 29, 2022
Self-attentive task GAN for space domain awareness data augmentation.

SATGAN TODO: update the article URL once published. Article about this implemention The self-attentive task generative adversarial network (SATGAN) le

Nathan 2 Mar 24, 2022
CVPR2021 Workshop - HDRUNet: Single Image HDR Reconstruction with Denoising and Dequantization.

HDRUNet [Paper Link] HDRUNet: Single Image HDR Reconstruction with Denoising and Dequantization By Xiangyu Chen, Yihao Liu, Zhengwen Zhang, Yu Qiao an

XyChen 105 Dec 20, 2022