A PyTorch implementation of "Graph Classification Using Structural Attention" (KDD 2018).

Overview

GAM

PWC codebeat badge repo sizebenedekrozemberczki

A PyTorch implementation of Graph Classification Using Structural Attention (KDD 2018).

Abstract

Graph classification is a problem with practical applications in many different domains. To solve this problem, one usually calculates certain graph statistics (i.e., graph features) that help discriminate between graphs of different classes. When calculating such features, most existing approaches process the entire graph. In a graphlet-based approach, for instance, the entire graph is processed to get the total count of different graphlets or subgraphs. In many real-world applications, however, graphs can be noisy with discriminative patterns confined to certain regions in the graph only. In this work, we study the problem of attention-based graph classification . The use of attention allows us to focus on small but informative parts of the graph, avoiding noise in the rest of the graph. We present a novel RNN model, called the Graph Attention Model (GAM), that processes only a portion of the graph by adaptively selecting a sequence of “informative” nodes. Experimental results on multiple real-world datasets show that the proposed method is competitive against various well-known methods in graph classification even though our method is limited to only a portion of the graph.

This repository provides an implementation for GAM as described in the paper:

Graph Classification using Structural Attention. John Boaz Lee, Ryan Rossi, and Xiangnan Kong KDD, 2018. [Paper]

Requirements

The codebase is implemented in Python 3.5.2. package versions used for development are just below.

networkx           2.4
tqdm               4.28.1
numpy              1.15.4
pandas             0.23.4
texttable          1.5.0
argparse           1.1.0
sklearn            0.20.0
torch              1.2.0.
torchvision        0.3.0

Datasets

The code takes graphs for training from an input folder where each graph is stored as a JSON. Graphs used for testing are also stored as JSON files. Every node id, node label and class has to be indexed from 0. Keys of dictionaries and nested dictionaries are stored strings in order to make JSON serialization possible.

For example these JSON files have the following key-value structure:

{"target": 1,
 "edges": [[0, 1], [0, 4], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]],
 "labels": {"0": 2, "1": 3, "2": 2, "3": 3, "4": 4},
 "inverse_labels": {"2": [0, 2], "3": [1, 3], "4": [4]}}

The **target key** has an integer value, which is the ID of the target class (e.g. Carcinogenicity). The **edges key** has an edge list value for the graph of interest. The **labels key** has a dictonary value for each node, these labels are stored as key-value pairs (e.g. node - atom pair). The **inverse_labels key** has a key for each node label and the values are lists containing the nodes that have a specific node label.

Options

Training a GAM model is handled by the src/main.py script which provides the following command line arguments.

Input and output options

  --train-graph-folder   STR    Training graphs folder.      Default is `input/train/`.
  --test-graph-folder    STR    Testing graphs folder.       Default is `input/test/`.
  --prediction-path      STR    Path to store labels.        Default is `output/erdos_predictions.csv`.
  --log-path             STR    Log json path.               Default is `logs/erdos_gam_logs.json`. 

Model options

  --repetitions          INT         Number of scoring runs.                  Default is 10. 
  --batch-size           INT         Number of graphs processed per batch.    Default is 32. 
  --time                 INT         Time budget.                             Default is 20. 
  --step-dimensions      INT         Neurons in step layer.                   Default is 32. 
  --combined-dimensions  INT         Neurons in shared layer.                 Default is 64. 
  --epochs               INT         Number of GAM training epochs.           Default is 10. 
  --learning-rate        FLOAT       Learning rate.                           Default is 0.001.
  --gamma                FLOAT       Discount rate.                           Default is 0.99. 
  --weight-decay         FLOAT       Weight decay.                            Default is 10^-5. 

Examples

The following commands learn a neural network, make predictions, create logs, and write the latter ones to disk.

Training a GAM model on the default dataset. Saving predictions and logs at default paths.

python src/main.py

Training a GAM model for a 100 epochs with a batch size of 512.

python src/main.py --epochs 100 --batch-size 512

Setting a high time budget for the agent.

python src/main.py --time 128

Training a model with some custom learning rate and epoch number.

python src/main.py --learning-rate 0.001 --epochs 200

License


Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.
Benedek Rozemberczki
generate-2D-quadrilateral-mesh-with-neural-networks-and-tree-search

generate-2D-quadrilateral-mesh-with-neural-networks-and-tree-search This repository contains single-threaded TreeMesh code. I'm Hua Tong, a senior stu

Hua Tong 18 Sep 21, 2022
General purpose Slater-Koster tight-binding code for electronic structure calculations

tight-binder Introduction General purpose tight-binding code for electronic structure calculations based on the Slater-Koster approximation. The code

9 Dec 15, 2022
Code for the paper Progressive Pose Attention for Person Image Generation in CVPR19 (Oral).

Pose-Transfer Code for the paper Progressive Pose Attention for Person Image Generation in CVPR19(Oral). The paper is available here. Video generation

Tengteng Huang 679 Jan 04, 2023
Adaptive Attention Span for Reinforcement Learning

Adaptive Transformers in RL Official implementation of Adaptive Transformers in RL In this work we replicate several results from Stabilizing Transfor

100 Nov 15, 2022
Supplementary materials for ISMIR 2021 LBD paper "Evaluation of Latent Space Disentanglement in the Presence of Interdependent Attributes"

Evaluation of Latent Space Disentanglement in the Presence of Interdependent Attributes Supplementary materials for ISMIR 2021 LBD submission: K. N. W

Karn Watcharasupat 2 Oct 25, 2021
The PyTorch re-implement of a 3D CNN Tracker to extract coronary artery centerlines with state-of-the-art (SOTA) performance. (paper: 'Coronary artery centerline extraction in cardiac CT angiography using a CNN-based orientation classifier')

The PyTorch re-implement of a 3D CNN Tracker to extract coronary artery centerlines with state-of-the-art (SOTA) performance. (paper: 'Coronary artery centerline extraction in cardiac CT angiography

James 135 Dec 23, 2022
Source Code for Simulations in the Publication "Can the brain use waves to solve planning problems?"

Code for Simulations in the Publication Can the brain use waves to solve planning problems? Installing Required Python Packages Please use Python vers

EMD Group 2 Jul 01, 2022
Synthetic Humans for Action Recognition, IJCV 2021

SURREACT: Synthetic Humans for Action Recognition from Unseen Viewpoints Gül Varol, Ivan Laptev and Cordelia Schmid, Andrew Zisserman, Synthetic Human

Gul Varol 59 Dec 14, 2022
Visualizer using audio and semantic analysis to explore BigGAN (Brock et al., 2018) latent space.

BigGAN Audio Visualizer Description This visualizer explores BigGAN (Brock et al., 2018) latent space by using pitch/tempo of an audio file to generat

Rush Kapoor 2 Nov 21, 2022
Semi-Autoregressive Transformer for Image Captioning

Semi-Autoregressive Transformer for Image Captioning Requirements Python 3.6 Pytorch 1.6 Prepare data Please use git clone --recurse-submodules to clo

YE Zhou 23 Dec 09, 2022
Here is the diagnostic tool for BMVC 2021 paper Diagnosing Errors in Video Relation Detectors.

Here is the diagnostic tool for BMVC 2021 paper Diagnosing Errors in Video Relation Detectors. We provide a tiny ground truth file demo_gt.json, and t

Shuo Chen 3 Dec 26, 2022
免费获取http代理并生成proxifier配置文件

freeproxy 免费获取http代理并生成proxifier配置文件 公众号:台下言书 工具说明:https://mp.weixin.qq.com/s?__biz=MzIyNDkwNjQ5Ng==&mid=2247484425&idx=1&sn=56ccbe130822aa35038095317

说书人 32 Mar 25, 2022
pytorch implementation of fast-neural-style

fast-neural-style 🌇 🚀 NOTICE: This codebase is no longer maintained, please use the codebase from pytorch examples repository available at pytorch/e

Abhishek Kadian 405 Dec 15, 2022
Simple (but Strong) Baselines for POMDPs

Recurrent Model-Free RL is a Strong Baseline for Many POMDPs Welcome to the POMDP world! This repo provides some simple baselines for POMDPs, specific

Tianwei V. Ni 172 Dec 29, 2022
Editing a Conditional Radiance Field

Editing Conditional Radiance Fields Project | Paper | Video | Demo Editing Conditional Radiance Fields Steven Liu, Xiuming Zhang, Zhoutong Zhang, Rich

Steven Liu 216 Dec 30, 2022
JAX + dataclasses

jax_dataclasses jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for: Pytree registrati

Brent Yi 35 Dec 21, 2022
Source code for "MusCaps: Generating Captions for Music Audio" (IJCNN 2021)

MusCaps: Generating Captions for Music Audio Ilaria Manco1 2, Emmanouil Benetos1, Elio Quinton2, Gyorgy Fazekas1 1 Queen Mary University of London, 2

Ilaria Manco 57 Dec 07, 2022
Notebook and code to synthesize complex and highly dimensional datasets using Gretel APIs.

Gretel Trainer This code is designed to help users successfully train synthetic models on complex datasets with high row and column counts. The code w

Gretel.ai 24 Nov 03, 2022
Repository of 3D Object Detection with Pointformer (CVPR2021)

3D Object Detection with Pointformer This repository contains the code for the paper 3D Object Detection with Pointformer (CVPR 2021) [arXiv]. This wo

Zhuofan Xia 117 Jan 06, 2023
PipeTransformer: Automated Elastic Pipelining for Distributed Training of Large-scale Models

PipeTransformer: Automated Elastic Pipelining for Distributed Training of Large-scale Models This repository is the official implementation of the fol

DistributedML 41 Dec 06, 2022