Code for Understanding Pooling in Graph Neural Networks

SRC: Select, Reduce, Connect

This repository contains the code used for the experiments of:

"Understanding Pooling in Graph Neural Networks"

Setup

Install TensorFlow and other dependencies:

pip install -r requirements.txt

Running experiments

Experiments are found in the following folders:

  • autoencoder/
  • spectral_similarity/
  • graph_classification/

Each folder has a bash script called run_all.sh that will reproduce the results reported in the paper.

To generate the plots and tables included in the paper, you can use the plots.py, plots_datasets.py, or tables.py scripts found in the folders.

To run experiments for an individual pooling operator, you can use the run_[OPERATOR NAME].py scripts in each folder.
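
If you prefer to drive the experiments from Python, the following is a minimal sketch (not part of the repository) that locates the run_all.sh script and the per-operator run_*.py scripts in an experiment folder and, optionally, runs one of them. It assumes that each script is meant to be launched from inside its own folder; the helper names find_scripts and run_folder are hypothetical.

```python
import subprocess
import sys
from pathlib import Path

# Experiment folders listed in this README.
EXPERIMENT_DIRS = ["autoencoder", "spectral_similarity", "graph_classification"]


def find_scripts(folder):
    """Return (path of run_all.sh or None, sorted list of run_*.py paths) for one folder."""
    folder = Path(folder)
    run_all = folder / "run_all.sh"
    per_operator = sorted(folder.glob("run_*.py"))
    return (run_all if run_all.is_file() else None), per_operator


def run_folder(folder, operator_script=None, dry_run=True):
    """Build (and unless dry_run is True, execute) the command for one folder.

    Without operator_script, run_all.sh is used; otherwise the named run_*.py
    script. Assumption: the working directory must be the experiment folder.
    """
    run_all, per_operator = find_scripts(folder)
    if operator_script is None:
        if run_all is None:
            raise FileNotFoundError(f"no run_all.sh in {folder}")
        cmd = ["bash", run_all.name]
    else:
        if operator_script not in {p.name for p in per_operator}:
            raise FileNotFoundError(f"{operator_script} not found in {folder}")
        cmd = [sys.executable, operator_script]
    if not dry_run:
        subprocess.run(cmd, cwd=folder, check=True)
    return cmd
```

For example, run_folder("autoencoder", dry_run=False) would run the full autoencoder suite from the repository root, while the default dry run only returns the command it would execute.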

The pooling operators that we used for the experiments are found in layers/ (trainable) and modules/ (non-trainable). The GNN architectures used in the experiments are found in models/.

The SRCPool class

The core of this repository is the SRCPool class that implements a general interface to create SRC pooling layers with the Keras API.

Our implementations of MinCutPool, DiffPool, LaPool, Top-K, and SAGPool based on the SRCPool class can be found in src/layers.

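In general, SRC layers compute:

$$
\mathcal{S} = \mathrm{SEL}(\mathcal{G}) = \{\mathcal{S}_k\}_{k=1:K}; \qquad
\mathcal{X}' = \{\mathrm{RED}(\mathcal{G}, \mathcal{S}_k)\}_{k=1:K}; \qquad
\mathcal{E}' = \{\mathrm{CON}(\mathcal{G}, \mathcal{S}_k, \mathcal{S}_l)\}_{k,l=1:K}
$$

where $\mathcal{G}$ is the input graph, $\mathrm{SEL}$ is a node-equivariant selection function that computes the supernode assignments $\mathcal{S}_k$, $\mathrm{RED}$ is a permutation-invariant function that reduces each supernode to the attributes of a new node, and $\mathrm{CON}$ is a permutation-invariant connection function that computes the links between the pooled nodes.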

By extending this class, it is possible to create any pooling layer in the SRC framework.
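
To make the three functions concrete, here is a minimal NumPy sketch of one common dense instance of the framework, the cluster-based form used (with additional details such as normalization) by operators like MinCutPool and DiffPool: SEL returns an N x K assignment matrix S, RED computes X' = S^T X, and CON computes A' = S^T A S. The assignment is fixed by hand instead of learned, and the function names are illustrative, not the repository's API.

```python
import numpy as np


def select(X, A):
    """SEL: hand-coded assignment of 4 nodes to K=2 supernodes, {0, 1} and {2, 3}.

    A trainable operator would compute S from X and A instead.
    """
    return np.array([[1.0, 0.0],
                     [1.0, 0.0],
                     [0.0, 1.0],
                     [0.0, 1.0]])


def reduce(X, S):
    """RED: the features of each supernode are the sum of its members' features."""
    return S.T @ X


def connect(A, S):
    """CON: the weight between supernodes k and l is the total weight of edges between their members."""
    return S.T @ A @ S


# Path graph 0 - 1 - 2 - 3 with one feature per node.
X = np.array([[1.0], [2.0], [3.0], [4.0]])
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

S = select(X, A)
X_pool = reduce(X, S)
A_pool = connect(A, S)
print(X_pool.ravel().tolist())  # [3.0, 7.0]
print(A_pool.tolist())          # [[2.0, 1.0], [1.0, 2.0]]
```

The off-diagonal 1 in A_pool counts the single edge (1, 2) that crosses the two supernodes, while the diagonal entries count each intra-cluster edge once per direction.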

Input

  • X: Tensor of shape ([batch], N, F) representing node features;
  • A: Tensor or SparseTensor of shape ([batch], N, N) representing the adjacency matrix;
  • I: (optional) Tensor of integers with shape (N, ) representing the batch index;
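
When the batch index I is given, the graphs of a batch are assumed to be stacked into a single disjoint union (often called disjoint mode): node features are concatenated along the node axis, adjacency matrices are placed on the block diagonal, and I[i] holds the index of the graph that node i belongs to. The NumPy sketch below builds the three inputs in this layout from a list of (X, A) pairs; the helper to_disjoint is hypothetical, and a dense A is used only for readability.

```python
import numpy as np


def to_disjoint(graphs):
    """Stack a list of (X, A) pairs into disjoint inputs (X, A, I).

    X: (N, F) concatenated node features
    A: (N, N) block-diagonal adjacency (dense here; sparse in practice)
    I: (N,) batch index, I[i] = index of the graph that contains node i
    """
    X = np.concatenate([x for x, _ in graphs], axis=0)
    n_total = X.shape[0]
    A = np.zeros((n_total, n_total), dtype=graphs[0][1].dtype)
    I = np.empty(n_total, dtype=np.int64)
    offset = 0
    for g, (x, a) in enumerate(graphs):
        n = x.shape[0]
        A[offset:offset + n, offset:offset + n] = a  # no edges across graphs
        I[offset:offset + n] = g
        offset += n
    return X, A, I


# A 2-node graph and a 3-node triangle, each with 3 features per node.
g1 = (np.ones((2, 3)), np.array([[0, 1], [1, 0]], dtype=float))
g2 = (np.zeros((3, 3)), np.ones((3, 3)) - np.eye(3))
X, A, I = to_disjoint([g1, g2])
print(X.shape, A.shape, I.tolist())  # (5, 3) (5, 5) [0, 0, 1, 1, 1]
```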

Output

  • X_pool: Tensor of shape ([batch], K, F), representing the node features of the output. K is the number of output nodes and depends on the specific pooling strategy;
  • A_pool: Tensor or SparseTensor of shape ([batch], K, K) representing the adjacency matrix of the output;
  • I_pool: (only if I was given as input) Tensor of integers with shape (K, ) representing the batch index of the output;
  • S_pool: (if return_sel=True) Tensor or SparseTensor representing the supernode assignments;
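
The number of output nodes K depends on the operator. As an illustration, the NumPy sketch below mimics a Top-K-style selection on a batch with index I: every node has a score, the ceil(ratio * n_g) highest-scoring nodes of each graph g are kept, and X_pool, A_pool, and I_pool are obtained by indexing with the kept nodes. The scores are passed in directly and ratio is an assumed parameter; a trainable Top-K layer would compute the scores from the node features, so this shows only the selection logic, not the repository's layer.

```python
import math

import numpy as np


def topk_pool(X, A, I, scores, ratio=0.5):
    """Keep the ceil(ratio * n_g) highest-scoring nodes of each graph g.

    Returns X_pool (K, F), A_pool (K, K), I_pool (K,) and the kept indices,
    which play the role of a sparse supernode assignment.
    """
    kept = []
    for g in np.unique(I):
        nodes = np.flatnonzero(I == g)               # node ids of graph g
        k = math.ceil(ratio * len(nodes))            # K depends on graph sizes
        order = np.argsort(-scores[nodes], kind="stable")
        kept.append(np.sort(nodes[order[:k]]))       # keep original node order
    idx = np.concatenate(kept)
    X_pool = X[idx]                   # copy the features of the kept nodes
    A_pool = A[np.ix_(idx, idx)]      # induced subgraph on the kept nodes
    I_pool = I[idx]                   # batch index of the kept nodes
    return X_pool, A_pool, I_pool, idx


# Two graphs: nodes {0, 1} and {2, 3, 4}, block-diagonal adjacency.
X = np.arange(10, dtype=float).reshape(5, 2)
A = np.zeros((5, 5))
A[0, 1] = A[1, 0] = 1.0
A[2:, 2:] = 1.0 - np.eye(3)
I = np.array([0, 0, 1, 1, 1])
scores = np.array([0.1, 0.9, 0.5, 0.2, 0.8])
X_pool, A_pool, I_pool, idx = topk_pool(X, A, I, scores, ratio=0.5)
print(idx.tolist(), I_pool.tolist())  # [1, 2, 4] [0, 1, 1]
```

With ratio=0.5, the 2-node graph keeps one node and the 3-node graph keeps two, so K = 3 for this batch.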

API

  • pool(X, A, I, **kwargs): pools the graph and returns the reduced node features and adjacency matrix. If the batch index I is not None, a reduced version of I will be returned as well. Any given kwargs will be passed as keyword arguments to select(), reduce() and connect() if any matching key is found. The mandatory arguments of pool() (X, A, and I) must be computed in call() by calling self.get_inputs(inputs).
  • select(X, A, I, **kwargs): computes supernode assignments mapping the nodes of the input graph to the nodes of the output.
  • reduce(X, S, **kwargs): reduces the supernodes to form the nodes of the pooled graph.
  • connect(A, S, **kwargs): connects the reduced supernodes.
  • reduce_index(I, S, **kwargs): helper function to reduce the batch index (only called if I is given as input).
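
pool() forwards a keyword argument only to the methods that accept a parameter with that name. One way to implement this kind of dispatch is with inspect.signature, as in the sketch below; it illustrates the behavior described above and is not necessarily how SRCPool does it internally. The helper filter_kwargs and the toy keyword arguments (temperature, reduction, normalize) are hypothetical.

```python
import inspect


def filter_kwargs(fn, kwargs):
    """Return the subset of kwargs whose keys match parameter names of fn."""
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}


# Toy stand-ins for select/reduce/connect, each with its own keyword argument.
def select(X, A, I, temperature=1.0):
    return ("select", temperature)


def reduce(X, S, reduction="sum"):
    return ("reduce", reduction)


def connect(A, S, normalize=False):
    return ("connect", normalize)


def pool(X, A, I, **kwargs):
    """Call select/reduce/connect, passing each one only the kwargs it accepts.

    Keys matching no method are ignored; the positional inputs (X, A, I, S)
    should not be repeated as keyword arguments.
    """
    S = select(X, A, I, **filter_kwargs(select, kwargs))
    X_pool = reduce(X, S, **filter_kwargs(reduce, kwargs))
    A_pool = connect(A, S, **filter_kwargs(connect, kwargs))
    return X_pool, A_pool, S
```

For instance, pool(X, A, I, temperature=0.5, normalize=True) passes temperature only to select() and normalize only to connect().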

When overriding any function of the API, it is possible to access the true number of nodes of the input (N) as a Tensor in the instance variable self.N (this is populated by self.get_inputs() at the beginning of call()).

Arguments:

  • return_sel: if True, the Tensor used to represent supernode assignments will be returned with X_pool, A_pool, and I_pool;
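
The skeleton below mirrors the structure of this API in plain NumPy so that it runs without TensorFlow: a base class whose pool() chains select(), reduce(), and connect(), calls reduce_index() when a batch index is given, and appends S when return_sel=True. It is a simplified, hypothetical stand-in for SRCPool (the real class is a Keras layer that obtains X, A, I, and self.N through self.get_inputs(inputs) in call(), and forwards kwargs); the class names MiniSRC and PairPool are made up. As with SRCPool, a new operator is written by overriding the SRC methods, here only select().

```python
import numpy as np


class MiniSRC:
    """Framework-free stand-in for the SRCPool interface (dense assignments, no kwargs forwarding)."""

    def __init__(self, return_sel=False):
        self.return_sel = return_sel
        self.N = None  # number of input nodes, set before pooling

    def __call__(self, X, A, I=None):
        self.N = X.shape[0]  # plays the role of self.get_inputs() filling self.N
        return self.pool(X, A, I)

    def pool(self, X, A, I=None):
        S = self.select(X, A, I)
        out = [self.reduce(X, S), self.connect(A, S)]
        if I is not None:
            out.append(self.reduce_index(I, S))
        if self.return_sel:
            out.append(S)
        return tuple(out)

    def select(self, X, A, I):
        raise NotImplementedError("subclasses define the selection")

    def reduce(self, X, S):
        return S.T @ X  # default: sum the features of each supernode

    def connect(self, A, S):
        return S.T @ A @ S  # default: sum the edge weights between supernodes

    def reduce_index(self, I, S):
        # Each supernode takes the batch index of one of its members
        # (assumes a supernode never spans two graphs).
        return I[np.argmax(S, axis=0)]


class PairPool(MiniSRC):
    """Hypothetical operator: merges nodes (0, 1), (2, 3), ... into one supernode each."""

    def select(self, X, A, I):
        K = (self.N + 1) // 2
        S = np.zeros((self.N, K))
        S[np.arange(self.N), np.arange(self.N) // 2] = 1.0
        return S


# Two 2-node graphs in disjoint layout.
X = np.arange(4, dtype=float).reshape(4, 1)
A = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=float)
I = np.array([0, 0, 1, 1])
X_pool, A_pool, I_pool, S = PairPool(return_sel=True)(X, A, I)
print(X_pool.ravel().tolist(), I_pool.tolist())  # [1.0, 5.0] [0, 1]
```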

Owner

Daniele Grattarola, PhD student @ Università della Svizzera italiana