A PyTorch implementation of "Capsule Graph Neural Network" (ICLR 2019).

CapsGNN

Abstract

The high-quality node embeddings learned from the Graph Neural Networks (GNNs) have been applied to a wide range of node-based applications and some of them have achieved state-of-the-art (SOTA) performance. However, when applying node embeddings learned from GNNs to generate graph embeddings, the scalar node representation may not suffice to preserve the node/graph properties efficiently, resulting in sub-optimal graph embeddings. Inspired by the Capsule Neural Network (CapsNet), we propose the Capsule Graph Neural Network (CapsGNN), which adopts the concept of capsules to address the weakness in existing GNN-based graph embedding algorithms. By extracting node features in the form of capsules, a routing mechanism can be utilized to capture important information at the graph level. As a result, our model generates multiple embeddings for each graph to capture graph properties from different aspects. The attention module incorporated in CapsGNN is used to tackle graphs of various sizes, which also enables the model to focus on critical parts of the graphs. Our extensive evaluations with 10 graph-structured datasets demonstrate that CapsGNN has a powerful mechanism that operates to capture macroscopic properties of the whole graph in a data-driven manner. It outperforms other SOTA techniques on several graph classification tasks, by virtue of the new instrument.

This repository provides a PyTorch implementation of CapsGNN as described in the paper:

Capsule Graph Neural Network. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Paper]

The core Capsule Neural Network implementation that this code adapts is available [here].

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0
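If you want to pin these versions, the table above converts directly into `requirements.txt` entries. The sketch below does that with plain string handling; the helper name `table_to_requirements` is hypothetical, and `argparse` is skipped because it ships with the Python 3 standard library.

```python
def table_to_requirements(table_text, skip=("argparse",)):
    """Convert 'name  version' rows into 'name==version' pins (hypothetical helper).

    Packages listed in `skip` are left out; argparse is part of the
    Python 3 standard library and needs no separate install.
    """
    pins = []
    for line in table_text.strip().splitlines():
        name, version = line.split()
        if name not in skip:
            pins.append("{}=={}".format(name, version))
    return pins


TABLE = """
networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0
"""

print("\n".join(table_to_requirements(TABLE)))
```

Install `torch` before `torch-scatter`, `torch-sparse` and `torch-cluster`: those extensions are built against the PyTorch that is already installed, which is why a mismatched PyTorch version (as in the torch-scatter comment below) can fail at build time.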

Datasets

The code takes graphs for training from an input folder where each graph is stored as a JSON file. Graphs used for testing are also stored as JSON files. Every node id and node label has to be indexed from 0. Dictionary keys are stored as strings to make JSON serialization possible.
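A minimal sketch of writing one graph in this format with Python's standard `json` module. The helper `save_graph` and the file name `0.json` are hypothetical. It also shows why the keys end up as strings: JSON object keys are always strings, so integer node ids come back as `"0"`, `"1"`, ... and a loader has to convert them with `int()`.

```python
import json
import os
import tempfile


def save_graph(path, edges, labels, target):
    """Write one graph as JSON (hypothetical helper, not part of the repository).

    edges:  list of (source, target) pairs with node ids indexed from 0
    labels: dict mapping node id -> node label
    target: integer class label of the whole graph
    """
    record = {
        "edges": [[int(u), int(v)] for u, v in edges],
        # JSON object keys must be strings, so store node ids as strings.
        "labels": {str(node): label for node, label in labels.items()},
        "target": int(target),
    }
    with open(path, "w") as handle:
        json.dump(record, handle)


folder = tempfile.mkdtemp()
path = os.path.join(folder, "0.json")
save_graph(path, [(0, 1), (1, 2), (2, 3), (3, 4)],
           {0: "A", 1: "B", 2: "C", 3: "A", 4: "B"}, target=1)

with open(path) as handle:
    graph = json.load(handle)

# Keys come back as strings; convert them to int before indexing nodes.
node_labels = {int(node): label for node, label in graph["labels"].items()}
print(sorted(graph["labels"]))  # ['0', '1', '2', '3', '4']
print(node_labels[0])           # A
```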

Every JSON file has the following key-value structure:

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
 "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
 "target": 1}

The **edges** key has an edge list value which describes the connectivity structure. The **labels** key holds the label of each node as a nested dictionary in which node identifiers are the keys and labels are the values. The **target** key has an integer value which is the class membership of the graph.
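The snippet below sketches how one such record could be turned into model inputs using only the standard library: a two-row edge index containing both directions of every edge, and a one-hot node-feature matrix. The function name `graph_to_inputs`, the undirected treatment of edges, and the dataset-wide `label_index` mapping are assumptions for illustration, not the repository's own preprocessing.

```python
import json


def graph_to_inputs(record, label_index):
    """Turn one graph record into (edge_index, features, target) (hypothetical helper).

    label_index maps each label value (e.g. "A") to a feature column; it is
    assumed to be built once over the whole dataset so columns agree across graphs.
    """
    # Assume undirected graphs: store each edge in both directions.
    sources, targets = [], []
    for u, v in record["edges"]:
        sources += [u, v]
        targets += [v, u]
    edge_index = [sources, targets]

    # Node ids are indexed from 0, so the node count is the largest id + 1.
    num_nodes = max(int(node) for node in record["labels"]) + 1
    features = [[0] * len(label_index) for _ in range(num_nodes)]
    for node, label in record["labels"].items():
        features[int(node)][label_index[label]] = 1
    return edge_index, features, record["target"]


record = json.loads('{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],'
                    ' "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},'
                    ' "target": 1}')
labels = sorted(set(record["labels"].values()))  # ['A', 'B', 'C']
label_index = {label: i for i, label in enumerate(labels)}
edge_index, features, target = graph_to_inputs(record, label_index)
print(edge_index[0])  # [0, 1, 1, 2, 2, 3, 3, 4]
print(features[2])    # [0, 0, 1]
```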

Outputs

The predictions are saved as a CSV file in the `output/` directory. The file has a header and a column with the graph identifiers, and the predictions are sorted by the identifier column.
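A small sketch of producing a file with that layout using the standard `csv` module: a header row, then one row per graph identifier, sorted by identifier. The column names `id` and `predicted_label` and the helper `write_predictions` are hypothetical; check a generated file for the actual header. Identifiers are sorted numerically here, assuming they are integer file stems such as `12` from `12.json`.

```python
import csv
import io


def write_predictions(handle, predictions):
    """Write {graph_id: predicted_class} as CSV sorted by graph id (hypothetical helper)."""
    writer = csv.writer(handle)
    writer.writerow(["id", "predicted_label"])  # header row (assumed column names)
    # Sort numerically so that "10" comes after "9" rather than after "1".
    for graph_id in sorted(predictions, key=int):
        writer.writerow([graph_id, predictions[graph_id]])


buffer = io.StringIO()
write_predictions(buffer, {"10": 0, "2": 1, "9": 1})
print(buffer.getvalue().splitlines())
# ['id,predicted_label', '2,1', '9,1', '10,0']
```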

Options

Training a CapsGNN model is handled by the `src/main.py` script which provides the following command line arguments.

Input and output options

  --training-graphs   STR    Training graphs folder.      Default is `dataset/train/`.
  --testing-graphs    STR    Testing graphs folder.       Default is `dataset/test/`.
  --prediction-path   STR    Output predictions file.     Default is `output/watts_predictions.csv`.
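For reuse in your own scripts, here is a minimal `argparse` sketch of the three input and output options with the defaults listed above, plus a helper that collects the `.json` graphs in a folder. It is an illustration, not the repository's `param_parser.py`, whose defaults may differ from this README; `parse_io_args` and `list_graph_files` are hypothetical names.

```python
import argparse
import glob
import os


def parse_io_args(argv=None):
    """Parse the input/output flags described above (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="CapsGNN input/output options (sketch).")
    parser.add_argument("--training-graphs", default="dataset/train/",
                        help="Training graphs folder.")
    parser.add_argument("--testing-graphs", default="dataset/test/",
                        help="Testing graphs folder.")
    parser.add_argument("--prediction-path", default="output/watts_predictions.csv",
                        help="Output predictions file.")
    # argparse maps dashes to underscores: --training-graphs -> args.training_graphs
    return parser.parse_args(argv)


def list_graph_files(folder):
    """Return the JSON graph files in a folder, sorted by file name."""
    return sorted(glob.glob(os.path.join(folder, "*.json")))


args = parse_io_args(["--testing-graphs", "my_test/"])
print(args.training_graphs, args.testing_graphs)  # dataset/train/ my_test/
```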

Model options

  --epochs                      INT     Number of epochs.                  Default is 100.
  --batch-size                  INT     Number of graphs per batch.        Default is 32.
  --gcn-filters                 INT     Number of filters in GCNs.         Default is 20.
  --gcn-layers                  INT     Number of GCNs chained together.   Default is 2.
  --inner-attention-dimension   INT     Number of neurons in attention.    Default is 20.  
  --capsule-dimensions          INT     Number of capsule neurons.         Default is 8.
  --number-of-capsules          INT     Number of capsules in layer.       Default is 8.
  --weight-decay                FLOAT   Weight decay of Adam.              Default is 10^-6.
  --lambd                       FLOAT   Regularization parameter.          Default is 0.5.
  --theta                       FLOAT   Reconstruction loss weight.        Default is 0.1.
  --learning-rate               FLOAT   Adam learning rate.                Default is 0.01.
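The README does not spell out where `--lambd` and `--theta` enter the objective. Assuming they follow the original CapsNet recipe, `--lambd` would down-weight the margin-loss term for absent classes and `--theta` would scale a reconstruction loss added on top. The pure-Python sketch below computes that assumed objective for one graph from the lengths of its class capsules; the margins 0.9 and 0.1 are the CapsNet defaults and, like the function name `capsgnn_style_loss`, are assumptions here.

```python
def capsgnn_style_loss(capsule_lengths, target, reconstruction_loss,
                       lambd=0.5, theta=0.1, m_plus=0.9, m_minus=0.1):
    """CapsNet-style margin loss plus weighted reconstruction loss (assumed formulation).

    capsule_lengths: length ||v_k|| of each class capsule, one value per class
    target: index of the true class
    reconstruction_loss: a precomputed non-negative reconstruction error
    """
    margin = 0.0
    for k, length in enumerate(capsule_lengths):
        if k == target:
            # True class: penalise a capsule shorter than m_plus.
            margin += max(0.0, m_plus - length) ** 2
        else:
            # Other classes: penalise capsules longer than m_minus, scaled by lambd.
            margin += lambd * max(0.0, length - m_minus) ** 2
    # theta weights the reconstruction term against the margin loss.
    return margin + theta * reconstruction_loss


print(round(capsgnn_style_loss([0.8, 0.3], target=0, reconstruction_loss=2.0), 4))  # 0.23
```

With the defaults, 0.01 comes from the true class, 0.5 × 0.04 = 0.02 from the other class, and 0.1 × 2.0 = 0.2 from reconstruction.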

Examples

The following commands learn a model and save the predictions. Training a model on the default dataset:

$ python src/main.py

Training a CapsGNN model for 100 epochs.

$ python src/main.py --epochs 100

Changing the batch size.

$ python src/main.py --batch-size 128

License

Comments
  •  Coordinate Addition module & Routing

    Hi, thanks for your CapsGNN code. I have some questions about the Coordinate Addition module and routing.

    1. Do you use Coordinate Addition module in this codes?
    2. In `/src/layers.py`, line 137: `c_ij = torch.nn.functional.softmax(b_ij, dim=0)`. Here `b_ij.size(0) == 1`, so why use `dim=0`?

    Thanks again.

    opened by S-rz 4
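A quick numerical check of the second question, written as a plain-Python sketch rather than the repository's code: if `b_ij` has size 1 along `dim=0`, a softmax over that dimension normalises each single entry by itself, so every coupling coefficient comes out as exactly 1.0 whatever the logits are. The helper `softmax_dim0` and the values below are illustrative.

```python
import math


def softmax_dim0(matrix):
    """Softmax over the first axis of a list-of-rows matrix (like dim=0)."""
    result_columns = []
    for column in zip(*matrix):
        peak = max(column)  # subtract the max for numerical stability
        exps = [math.exp(value - peak) for value in column]
        total = sum(exps)
        result_columns.append([e / total for e in exps])
    # Transpose the normalised columns back into rows.
    return [list(row) for row in zip(*result_columns)]


b_ij = [[2.0, -1.0, 0.5]]            # shape (1, 3): size 1 along dim 0
print(softmax_dim0(b_ij))            # [[1.0, 1.0, 1.0]]
print(softmax_dim0([[2.0], [0.0]]))  # two rows: the values now differ
```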
  • Something about reshape

    Hi @benedekrozemberczki ! Thank you for your work!

    I have a question at line 61 and 62 of CapsGNN/src/capsgnn.py

    hidden_representations = torch.cat(tuple(hidden_representations))
    hidden_representations = hidden_representations.view(1, self.args.gcn_layers, self.args.gcn_filters, -1)

    Why do you directly reshape (L*N, D) to (1, L, D, N) instead of reshaping first and then permuting, e.g.:

    hidden_representations = hidden_representations.view(1, self.args.gcn_layers, -1,self.args.gcn_filters).permute(0,1,3,2)

    Thank you for your help!

    opened by yanx27 4
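The difference raised in this question can be checked with NumPy, whose `reshape` and `transpose` behave like PyTorch's `view` and `permute` on contiguous data. Assuming, as the question states, that the concatenated tensor is `(L*N, D)` with N node rows per layer, reshaping directly to `(1, L, D, N)` reinterprets memory rather than transposing each layer, so node and feature values get interleaved. The sizes `L=1, N=2, D=3` are illustrative.

```python
import numpy as np

L, N, D = 1, 2, 3  # layers, nodes, filters (illustrative sizes)

# Row n of each layer block holds the D features of node n:
# [[0, 1, 2],    node 0
#  [3, 4, 5]]    node 1
hidden = np.arange(L * N * D).reshape(L * N, D)

# Direct reshape to (1, L, D, N), as in the quoted line.
direct = hidden.reshape(1, L, D, -1)
# Reshape to (1, L, N, D) first, then swap the last two axes.
permuted = hidden.reshape(1, L, -1, D).transpose(0, 1, 3, 2)

print(direct[0, 0].tolist())    # [[0, 1], [2, 3], [4, 5]]
print(permuted[0, 0].tolist())  # [[0, 3], [1, 4], [2, 5]]
```

Both results have the same shape, but only in the permuted one does row d hold feature d for every node.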
  • Reproduce Issues

    Hi, thanks for your PyTorch code for CapsGNN. I tried to run the code on NCI, DD, and other graph classification datasets, but it doesn't work (for example, the training loss converges to 2.0 and the test accuracy is about 50% on NCI1 after several iterations). What should I do to run this code on NCI, DD, etc.? Thanks again.

    opened by veophi 1
  • D&D dataset

    I notice some datasets in your paper such as D&D dataset. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

    opened by try-to-anything 1
  • Other datasets

    I notice some datasets in your paper such as RE-M5K and RE-M12K. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

    opened by HongyangGao 1
  • Not able to install torch-scatter with torch 0.4.1

    Hello,

    Thanks for sharing the implementation.

    While trying to run your code, I got an error when installing the environment. I have torch 0.4.1 but am not able to install torch-scatter. I got the following error: fatal error: torch/extension.h: No such file or directory

    But I can successfully install them for torch 1.0.

    Does your code work with torch 1.0? Or how can I install torch-scatter for torch 0.4.1?

    Details:

    $ pip list
    Package          Version
    backcall         0.1.0
    certifi          2018.8.24
    ....
    torch            0.4.1.post2
    torch-geometric  1.1.1
    torchfile        0.1.0
    torchvision      0.2.1
    tornado          5.1
    tqdm             4.31.1
    traitlets        4.3.2
    urllib3          1.23
    visdom           0.1.8.5
    vispy            0.5.3
    ....

    $ pip install torch-scatter

    opened by jkuh626 1
  • how to repeat your experiments?

    Enumerating feature and target values.

    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 14754.82it/s]

    Training started.

    Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00, 1.90it/s]
    CapsGNN (Loss=0.7279): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.92it/s]

    Scoring.

    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 128.47it/s]

    Accuracy: 0.3333

    The accuracy is too low.

    opened by robotzheng 1
  • default input dir for graphs is "input"

    The README mentions the default train and test graphs to be in dataset/train and dataset/test, whereas they are in input/train and input/test respectively. The param_parser.py has the correct default paths nevertheless.

    opened by Utkarsh87 0
Releases(v_0001)
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.