A PyTorch implementation of "Capsule Graph Neural Network" (ICLR 2019).

Overview

CapsGNN

Abstract

The high-quality node embeddings learned from Graph Neural Networks (GNNs) have been applied to a wide range of node-based applications, and some of them have achieved state-of-the-art (SOTA) performance. However, when node embeddings learned from GNNs are used to generate graph embeddings, the scalar node representation may not suffice to preserve the node/graph properties efficiently, resulting in sub-optimal graph embeddings. Inspired by the Capsule Neural Network (CapsNet), we propose the Capsule Graph Neural Network (CapsGNN), which adopts the concept of capsules to address this weakness of existing GNN-based graph embedding algorithms. By extracting node features in the form of capsules, a routing mechanism can be utilized to capture important information at the graph level. As a result, our model generates multiple embeddings for each graph to capture graph properties from different aspects. The attention module incorporated in CapsGNN is used to handle graphs of various sizes, which also enables the model to focus on critical parts of the graphs. Our extensive evaluations with 10 graph-structured datasets demonstrate that CapsGNN has a powerful mechanism that operates to capture macroscopic properties of the whole graph in a data-driven manner. It outperforms other SOTA techniques on several graph classification tasks, by virtue of the new instrument.
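The abstract mentions capsules and a routing mechanism. For readers new to capsules, here is a dependency-free sketch of the two building blocks in their original CapsNet form ("Dynamic Routing Between Capsules"): the squashing non-linearity and routing by agreement. This is an illustration, not code from this repository: the function names, the nested-list layout `u_hat[i][j]` (prediction of lower capsule i for upper capsule j) and the default of three routing iterations are assumptions made for the example, and the repository itself works on PyTorch tensors.

```python
import math
from typing import List

Vector = List[float]


def squash(s: Vector) -> Vector:
    """CapsNet squashing: v = (|s|^2 / (1 + |s|^2)) * s / |s|, so 0 <= |v| < 1."""
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:
        return [0.0 for _ in s]
    scale = norm_sq / (1.0 + norm_sq) / math.sqrt(norm_sq)
    return [scale * x for x in s]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax over a plain list."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat: List[List[Vector]], iterations: int = 3) -> List[Vector]:
    """Routing by agreement; returns one output vector per upper capsule."""
    n_lower, n_upper, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_upper for _ in range(n_lower)]  # routing logits b_ij
    v: List[Vector] = []
    for it in range(iterations):
        # Each lower capsule splits its output over the upper capsules.
        c = [softmax(b_i) for b_i in b]
        v = []
        for j in range(n_upper):
            s_j = [sum(c[i][j] * u_hat[i][j][k] for i in range(n_lower)) for k in range(dim)]
            v.append(squash(s_j))
        if it < iterations - 1:
            # Raise the logit where a prediction agrees with the upper capsule's output.
            for i in range(n_lower):
                for j in range(n_upper):
                    b[i][j] += sum(u_hat[i][j][k] * v[j][k] for k in range(dim))
    return v


if __name__ == "__main__":
    # Both lower capsules predict the same vector for upper capsule 0,
    # but opposite vectors for upper capsule 1, which therefore cancel out.
    predictions = [[[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [-1.0, 0.0]]]
    outputs = dynamic_routing(predictions)
    print([round(math.hypot(*v), 3) for v in outputs])
```

The length of each output vector acts as the "presence" score of that capsule, which is why squashing keeps it below 1.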

This repository provides a PyTorch implementation of CapsGNN as described in the paper:

Capsule Graph Neural Network. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Paper]

The core Capsule Neural Network implementation that was adapted is available [here].

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0

Datasets

The code takes graphs for training from an input folder where each graph is stored as a JSON file. Graphs used for testing are also stored as JSON files. Every node id and node label has to be indexed from 0. Dictionary keys are stored as strings in order to make JSON serialization possible.

Every JSON file has the following key-value structure:

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
 "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
 "target": 1}

The **edges** key has an edge list value which describes the connectivity structure. The **labels** key holds the label of each node as a nested dictionary: node identifiers are the keys and labels are the values. The **target** key has an integer value which is the class membership.
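A minimal sketch of reading one graph file in this layout with only the standard library. It converts the string keys back to integer node ids, checks the 0-based indexing requirement and builds one-hot node features. The function names, the validation rules and the one-hot encoding are assumptions for illustration; the repository's own preprocessing may build features differently.

```python
import json
from typing import Dict, Iterable, List, Tuple

Edge = Tuple[int, int]


def parse_graph(data: dict) -> Tuple[List[Edge], Dict[int, str], int]:
    """Validate and convert one decoded graph dictionary."""
    edges = [(int(src), int(dst)) for src, dst in data["edges"]]
    # JSON object keys are always strings, so node ids are converted back to ints.
    labels = {int(node): label for node, label in data["labels"].items()}
    target = int(data["target"])
    # Node ids are expected to run from 0 to n - 1 without gaps.
    if sorted(labels) != list(range(len(labels))):
        raise ValueError("node ids must be indexed from 0 without gaps")
    if any(src not in labels or dst not in labels for src, dst in edges):
        raise ValueError("edge endpoint without a label")
    return edges, labels, target


def load_graph(path: str) -> Tuple[List[Edge], Dict[int, str], int]:
    """Read one graph file stored in the JSON layout described above."""
    with open(path, encoding="utf-8") as handle:
        return parse_graph(json.load(handle))


def build_label_index(all_labels: Iterable[Dict[int, str]]) -> Dict[str, int]:
    """Map every distinct node label across the dataset to an index starting at 0."""
    distinct = sorted({label for labels in all_labels for label in labels.values()})
    return {label: index for index, label in enumerate(distinct)}


def one_hot_features(labels: Dict[int, str], label_index: Dict[str, int]) -> List[List[float]]:
    """One row per node, ordered by node id, with 1.0 in the column of its label."""
    features = []
    for node in range(len(labels)):
        row = [0.0] * len(label_index)
        row[label_index[labels[node]]] = 1.0
        features.append(row)
    return features
```

For the example file above, `build_label_index` maps `A`, `B` and `C` to 0, 1 and 2, so node 2 gets the feature row `[0.0, 0.0, 1.0]`. Building the index over all training and testing graphs keeps the feature columns consistent between them.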

Outputs

The predictions are saved as a CSV file in the `output/` directory. The file has a header and a column with the graph identifiers. Finally, the predictions are sorted by the identifier column.
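A sketch of writing a prediction file with the shape described here: a header row, a graph identifier column and rows sorted by identifier. The column names `graph_id` and `predicted_label`, and the assumption that identifiers are file-name stems such as `"12"` that should sort numerically, are hypothetical and not taken from the repository.

```python
import csv
import io
from typing import Dict, TextIO


def write_predictions(predictions: Dict[str, int], out: TextIO) -> None:
    """Write predictions as CSV with a header, sorted by graph identifier."""
    ids = list(predictions)
    # Sort numerically when every identifier is a digit string, else lexicographically.
    if all(graph_id.isdigit() for graph_id in ids):
        ids.sort(key=int)
    else:
        ids.sort()
    writer = csv.writer(out)
    writer.writerow(["graph_id", "predicted_label"])  # hypothetical header
    for graph_id in ids:
        writer.writerow([graph_id, predictions[graph_id]])


if __name__ == "__main__":
    buffer = io.StringIO()
    write_predictions({"10": 1, "2": 0}, buffer)
    print(buffer.getvalue())
```

Numeric sorting matters for stems like `"2"` and `"10"`: a plain string sort would put `"10"` first.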

Options

Training a CapsGNN model is handled by the `src/main.py` script which provides the following command line arguments.

Input and output options

  --training-graphs   STR    Training graphs folder.      Default is `input/train/`.
  --testing-graphs    STR    Testing graphs folder.       Default is `input/test/`.
  --prediction-path   STR    Output predictions file.     Default is `output/watts_predictions.csv`.

Model options

  --epochs                      INT     Number of epochs.                  Default is 100.
  --batch-size                  INT     Number of graphs per batch.        Default is 32.
  --gcn-filters                 INT     Number of filters in GCNs.         Default is 20.
  --gcn-layers                  INT     Number of GCNs chained together.   Default is 2.
  --inner-attention-dimension   INT     Number of neurons in attention.    Default is 20.  
  --capsule-dimensions          INT     Number of capsule neurons.         Default is 8.
  --number-of-capsules          INT     Number of capsules in layer.       Default is 8.
  --weight-decay                FLOAT   Weight decay of Adam.              Default is 10^-6.
  --lambd                       FLOAT   Regularization parameter.          Default is 0.5.
  --theta                       FLOAT   Reconstruction loss weight.        Default is 0.1.
  --learning-rate               FLOAT   Adam learning rate.                Default is 0.01.
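A minimal sketch of an `argparse` parser that accepts the model options listed above with their documented defaults. The repository's own parser may be organised differently; `build_model_parser` is a hypothetical name. It also shows why the code refers to attributes such as `self.args.gcn_layers`: `argparse` turns dashes in option names into underscores in attribute names.

```python
import argparse


def build_model_parser() -> argparse.ArgumentParser:
    """Hypothetical parser covering only the model options of src/main.py."""
    parser = argparse.ArgumentParser(description="Run CapsGNN (model options only).")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs.")
    parser.add_argument("--batch-size", type=int, default=32, help="Number of graphs per batch.")
    parser.add_argument("--gcn-filters", type=int, default=20, help="Number of filters in GCNs.")
    parser.add_argument("--gcn-layers", type=int, default=2, help="Number of GCNs chained together.")
    parser.add_argument("--inner-attention-dimension", type=int, default=20,
                        help="Number of neurons in attention.")
    parser.add_argument("--capsule-dimensions", type=int, default=8, help="Number of capsule neurons.")
    parser.add_argument("--number-of-capsules", type=int, default=8, help="Number of capsules in layer.")
    parser.add_argument("--weight-decay", type=float, default=1e-6, help="Weight decay of Adam.")
    parser.add_argument("--lambd", type=float, default=0.5, help="Regularization parameter.")
    parser.add_argument("--theta", type=float, default=0.1, help="Reconstruction loss weight.")
    parser.add_argument("--learning-rate", type=float, default=0.01, help="Adam learning rate.")
    return parser


if __name__ == "__main__":
    args = build_model_parser().parse_args(["--batch-size", "128"])
    # "--batch-size" is exposed as args.batch_size, "--gcn-layers" as args.gcn_layers.
    print(args.batch_size, args.gcn_layers)
```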

Examples

The following commands learn a model and save the predictions. Training a model on the default dataset:

$ python src/main.py

Training a CapsGNN model for 100 epochs.

$ python src/main.py --epochs 100

Changing the batch size.

$ python src/main.py --batch-size 128

License

Comments
  •  Coordinate Addition module & Routing

    Hi, thanks for your CapsGNN code. I have some questions about the Coordinate Addition module and Routing.

    1. Do you use Coordinate Addition module in this codes?
    2. In /src/layers.py, line 137: c_ij = torch.nn.functional.softmax(b_ij, dim=0). At this point b_ij.size(0) == 1, so why use dim=0?

    Thanks again.

    opened by S-rz 4
  • Something about reshape

    Hi @benedekrozemberczki ! Thank you for your work!

    I have a question at line 61 and 62 of CapsGNN/src/capsgnn.py

    hidden_representations = torch.cat(tuple(hidden_representations))
    hidden_representations = hidden_representations.view(1, self.args.gcn_layers, self.args.gcn_filters, -1)

    Why do you directly reshape (L*N, D) to (1, L, D, N) instead of reshaping first and then permuting, e.g.:

    hidden_representations = hidden_representations.view(1, self.args.gcn_layers, -1,self.args.gcn_filters).permute(0,1,3,2)

    Thank you for your help!

    opened by yanx27 4
  • Reproduce Issues

    Hi, thanks for your PyTorch code of CapsGNN. I tried to run the code on NCI, DD, and other graph classification datasets, but it doesn't work (for example, the training loss converges to 2.0 and the test accuracy is about 50% on NCI1 after several iterations). What should I do if I want to run this code on NCI, DD, etc.? Thanks again.

    opened by veophi 1
  • D&D dataset

    I notice some datasets in your paper such as D&D dataset. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

    opened by try-to-anything 1
  • Other datasets

    I notice some datasets in your paper such as RE-M5K and RE-M12K. May I know how to obtain these datasets? The processed datasets would be appreciated. Thank you.

    opened by HongyangGao 1
  • Not able to install torch-scatter with torch 0.4.1

    Hello,

    Thanks for sharing the implementation.

    While trying to run your code, I got an error when setting up the environment. I have torch 0.4.1 but am not able to install torch-scatter. I got the following error: fatal error: torch/extension.h: No such file or directory

    But I can successfully install them for torch 1.0.

    Is your code working for torch 1.0? Or how to install torch-scatter for torch 0.4.1?

    Details:

    $ pip list
    Package          Version
    backcall         0.1.0
    certifi          2018.8.24
    ....
    torch            0.4.1.post2
    torch-geometric  1.1.1
    torchfile        0.1.0
    torchvision      0.2.1
    tornado          5.1
    tqdm             4.31.1
    traitlets        4.3.2
    urllib3          1.23
    visdom           0.1.8.5
    vispy            0.5.3
    .... ....

    $ pip install torch-scatter

    opened by jkuh626 1
  • How to repeat your experiments?

    Enumerating feature and target values.

    100%|██████████| 60/60 [00:00<00:00, 14754.82it/s]

    Training started.

    Epochs: 100%|██████████| 10/10 [00:05<00:00, 1.90it/s]
    CapsGNN (Loss=0.7279): 100%|██████████| 1/1 [00:00<00:00, 1.92it/s]

    Scoring.

    100%|██████████| 30/30 [00:00<00:00, 128.47it/s]

    Accuracy: 0.3333

    The accuracy is too low.

    opened by robotzheng 1
  • default input dir for graphs is "input"

    The README mentions the default train and test graphs to be in dataset/train and dataset/test, whereas they are in input/train and input/test respectively. The param_parser.py has the correct default paths nevertheless.

    opened by Utkarsh87 0
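The reshape question in the comments above turns on how `view` treats memory: on a contiguous tensor it keeps the row-major element order, so viewing an (N, D) block as (D, N) is not the same as transposing it. Below is a dependency-free sketch with one GCN layer (L = 1, leading batch dimension dropped), three nodes and two filters. The helpers `flatten`, `view` and `transpose` are illustrative stand-ins for the tensor operations, not code from the repository.

```python
from typing import List

Matrix = List[List[int]]


def flatten(m: Matrix) -> List[int]:
    """Row-major flattening, i.e. the memory order of a contiguous tensor."""
    return [x for row in m for x in row]


def view(flat: List[int], rows: int, cols: int) -> Matrix:
    """Reinterpret row-major data with a new shape, like Tensor.view."""
    assert len(flat) == rows * cols, "view cannot change the number of elements"
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def transpose(m: Matrix) -> Matrix:
    """Swap the two axes, like .permute(1, 0) on a 2-D tensor."""
    return [list(col) for col in zip(*m)]


# Node-feature matrix for one layer: N = 3 nodes, D = 2 filters.
# Entry 10 * node + filter makes it easy to see where each value ends up.
nodes, filters = 3, 2
x = [[10 * n + d for d in range(filters)] for n in range(nodes)]

direct = view(flatten(x), filters, nodes)                 # like .view(..., D, N)
permuted = transpose(view(flatten(x), nodes, filters))   # like .view(..., N, D) then permute

print(direct)    # [[0, 1, 10], [11, 20, 21]]
print(permuted)  # [[0, 10, 20], [1, 11, 21]]
```

Each row of `permuted` holds one filter's values across all nodes, whereas the rows of `direct` mix values from different nodes and filters.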
Releases(v_0001)
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.