Implementation of "GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings" in PyTorch

Overview

PyGAS: Auto-Scaling GNNs in PyG


PyGAS is the practical realization of our GNNAutoScale (GAS) framework, which scales arbitrary message-passing GNNs to large graphs, as described in our paper:

Matthias Fey, Jan E. Lenssen, Frank Weichert, Jure Leskovec: GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings (ICML 2021)

GAS prunes entire sub-trees of the computation graph by utilizing historical embeddings from prior training iterations. This leads to constant GPU memory consumption with respect to the input mini-batch size while retaining maximal expressivity.
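The following pure-Python sketch illustrates why pruning sub-trees keeps each mini-batch's computation graph small. It is not PyGAS code. Without histories, an L-layer GNN must expand the full L-hop neighborhood of the batch. With historical embeddings, a step only touches the batch and its direct (1-hop) neighbors, because deeper layers read the cached embeddings of out-of-batch neighbors instead of recomputing them. The toy tree graph and the helper names `khop_nodes` and `gas_nodes` are hypothetical.

```python
from collections import deque


def khop_nodes(adj, batch, num_hops):
    """Return every node within `num_hops` hops of the batch (BFS)."""
    seen = set(batch)
    queue = deque((v, 0) for v in batch)
    while queue:
        v, depth = queue.popleft()
        if depth == num_hops:
            continue
        for u in adj[v]:
            if u not in seen:
                seen.add(u)
                queue.append((u, depth + 1))
    return seen


def gas_nodes(adj, batch):
    """Nodes touched per step when deeper layers pull out-of-batch
    embeddings from history: the batch plus its 1-hop neighbors."""
    return khop_nodes(adj, batch, 1)


# Toy graph: a complete binary tree with 31 nodes (children of v are 2v+1, 2v+2)
adj = {v: [] for v in range(31)}
for v in range(1, 31):
    parent = (v - 1) // 2
    adj[parent].append(v)
    adj[v].append(parent)

batch = [0]
for num_layers in (1, 2, 3, 4):
    full = len(khop_nodes(adj, batch, num_layers))
    gas = len(gas_nodes(adj, batch))
    print(f"{num_layers} layer(s): full graph {full} nodes, GAS {gas} nodes")
```

The full computation graph grows from 3 to 31 nodes as layers are added. The GAS-style count stays at 3, independent of depth.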

PyGAS is implemented in PyTorch and utilizes the PyTorch Geometric (PyG) library. It provides an easy-to-use interface to convert a common or custom GNN from PyG into its scalable variant:

from torch.nn import ModuleList
from torch_geometric.nn import SAGEConv
from torch_geometric_autoscale import ScalableGNN
from torch_geometric_autoscale import metis, permute, SubgraphLoader

class GNN(ScalableGNN):
    def __init__(self, num_nodes, in_channels, hidden_channels,
                 out_channels, num_layers):
        # * pool_size determines the number of pinned CPU buffers
        # * buffer_size determines the size of pinned CPU buffers,
        #   i.e. the maximum number of out-of-mini-batch nodes

        super().__init__(num_nodes, hidden_channels, num_layers,
                         pool_size=2, buffer_size=5000)

        self.convs = ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, adj_t, *args):
        for conv, history in zip(self.convs[:-1], self.histories):
            x = conv(x, adj_t).relu_()
            x = self.push_and_pull(history, x, *args)
        return self.convs[-1](x, adj_t)

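# `data` is assumed to be a PyG `Data` object that stores its graph as a
# sparse adjacency matrix in `data.adj_t` (e.g., via
# `torch_geometric.transforms.ToSparseTensor()`).
# Partition the graph into 40 parts with METIS and reorder its nodes so that
# every part occupies a contiguous range of node indices given by `ptr`:
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)
# Each mini-batch merges 10 randomly chosen parts:
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

model = GNN(...)
for batch, *args in loader:
    # `args` carries the mini-batch bookkeeping that `push_and_pull` consumes
    out = model(batch.x, batch.adj_t, *args)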
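The `metis`, `permute` and `SubgraphLoader` steps can be pictured with the pure-Python sketch below. It assumes a partition assignment is already given. The hard-coded `part` list stands in for METIS output, and METIS itself is not called. The sketch reorders nodes so that each partition occupies a contiguous index range, records the range boundaries in `ptr`, and groups `batch_size` partitions into each mini-batch. The helpers `partition_perm` and `partition_batches` are hypothetical and only mirror the roles of `perm` and `ptr`; they are not the library's implementation. The real loader's mini-batches also contain out-of-mini-batch neighbor nodes (cf. `buffer_size` above), which this sketch leaves out.

```python
import random


def partition_perm(part, num_parts):
    """Return (perm, ptr): perm lists old node ids sorted by partition, and
    the nodes of partition p occupy new indices ptr[p] .. ptr[p + 1] - 1."""
    # sorted() is stable, so node order within a partition is preserved
    perm = sorted(range(len(part)), key=lambda v: part[v])
    counts = [0] * num_parts
    for p in part:
        counts[p] += 1
    ptr = [0]
    for c in counts:
        ptr.append(ptr[-1] + c)
    return perm, ptr


def partition_batches(ptr, batch_size, shuffle=False, seed=0):
    """Group partition ids into mini-batches of `batch_size` partitions and
    yield the contiguous (start, end) index range of each partition."""
    parts = list(range(len(ptr) - 1))
    if shuffle:
        random.Random(seed).shuffle(parts)
    for i in range(0, len(parts), batch_size):
        chunk = parts[i:i + batch_size]
        yield [(ptr[p], ptr[p + 1]) for p in chunk]


# Hypothetical assignment of 8 nodes to 3 partitions (stand-in for METIS)
part = [2, 0, 1, 0, 2, 1, 0, 2]
perm, ptr = partition_perm(part, num_parts=3)
print(perm)  # [1, 3, 6, 2, 5, 0, 4, 7]
print(ptr)   # [0, 3, 5, 8]
for ranges in partition_batches(ptr, batch_size=2):
    print(ranges)  # [(0, 3), (3, 5)] then [(5, 8)]
```

Because each partition is contiguous after reordering, a mini-batch can be described by a few index ranges instead of an arbitrary list of node ids.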

A detailed description of ScalableGNN can be found in its implementation.
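Each `push_and_pull` call in `forward` can be understood as follows. It stores the freshly computed embeddings of the in-batch nodes in that layer's history. It also replaces the rows of out-of-batch nodes with their stored, possibly stale, embeddings. The pure-Python sketch below mimics this behavior with a hypothetical `ToyHistory` class that uses lists instead of tensors and pinned CPU buffers. It assumes that the first `batch_size` rows of `x` belong to in-batch nodes and the remaining rows to out-of-batch neighbors, with global node ids listed in `n_id`. It is an illustration, not the ScalableGNN implementation.

```python
class ToyHistory:
    """Per-layer embedding cache for all nodes (hypothetical stand-in)."""

    def __init__(self, num_nodes, channels):
        self.emb = [[0.0] * channels for _ in range(num_nodes)]

    def pull(self, ids):
        return [list(self.emb[i]) for i in ids]

    def push(self, rows, ids):
        for i, row in zip(ids, rows):
            self.emb[i] = list(row)


def push_and_pull(history, x, batch_size, n_id):
    """Keep fresh rows for in-batch nodes, swap out-of-batch rows for their
    historical embeddings, and store the fresh rows for later iterations."""
    fresh = x[:batch_size]
    stale = history.pull(n_id[batch_size:])
    history.push(fresh, n_id[:batch_size])
    return fresh + stale


hist = ToyHistory(num_nodes=5, channels=2)
# Step 1: nodes 0 and 1 are in the batch, node 2 is an out-of-batch neighbor;
# the row computed for node 2 ([9.0, 9.0]) is discarded in favor of history
out1 = push_and_pull(hist, [[1.0, 1.0], [2.0, 2.0], [9.0, 9.0]], 2, [0, 1, 2])
print(out1)  # [[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]]
# Step 2: node 2 is in the batch; node 1 is out-of-batch and reuses its
# embedding pushed in step 1
out2 = push_and_pull(hist, [[3.0, 3.0], [7.0, 7.0]], 1, [2, 1])
print(out2)  # [[3.0, 3.0], [2.0, 2.0]]
```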

Requirements

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric

where ${TORCH} should be replaced by either 1.7.0 or 1.8.0, and ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, cu110 or cu111, depending on your PyTorch installation.
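The substitution can be scripted. The sketch below is a hypothetical helper, not part of PyGAS. It builds the wheel index URL from a PyTorch version string and a CUDA version string as reported by PyTorch (e.g., `torch.__version__` and `torch.version.cuda`, the latter being `None` for CPU-only builds). It makes two assumptions. First, any patch release maps to `<major>.<minor>.0`, which matches the listed values 1.7.0 and 1.8.0. Second, the CUDA tag is `cu` followed by the version digits without the dot, which matches the listed tags.

```python
def pyg_wheel_url(torch_version, cuda_version=None):
    """Build the -f index URL used in the pip commands above.

    Assumed mappings: '1.8.1+cu102' -> '1.8.0', '10.2' -> 'cu102', None -> 'cpu'.
    """
    major, minor = torch_version.split("+")[0].split(".")[:2]
    torch_tag = f"{major}.{minor}.0"
    if cuda_version is None:
        cuda_tag = "cpu"
    else:
        cuda_tag = "cu" + cuda_version.replace(".", "")
    # Only the values listed in this README are accepted
    supported_torch = {"1.7.0", "1.8.0"}
    supported_cuda = {"cpu", "cu92", "cu101", "cu102", "cu110", "cu111"}
    if torch_tag not in supported_torch or cuda_tag not in supported_cuda:
        raise ValueError(f"unsupported combination: torch-{torch_tag}+{cuda_tag}")
    return f"https://pytorch-geometric.com/whl/torch-{torch_tag}+{cuda_tag}.html"


print(pyg_wheel_url("1.8.1+cu102", "10.2"))
# https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
print(pyg_wheel_url("1.7.0", None))
# https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
```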

Installation

pip install git+https://github.com/rusty1s/pyg_autoscale.git

or

python setup.py install

Project Structure

  • torch_geometric_autoscale/ contains the source code of PyGAS
  • examples/ contains examples to demonstrate how to apply GAS in practice
  • small_benchmark/ includes experiments to evaluate GAS performance on small-scale graphs
  • large_benchmark/ includes experiments to evaluate GAS performance on large-scale graphs

We use Hydra to manage hyperparameter configurations.

Cite

Please cite our paper if you use this code in your own work:

@inproceedings{Fey/etal/2021,
  title={{GNNAutoScale}: Scalable and Expressive Graph Neural Networks via Historical Embeddings},
  author={Fey, M. and Lenssen, J. E. and Weichert, F. and Leskovec, J.},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2021},
}
Owner
Matthias Fey
PhD student @ TU Dortmund University - Interested in Representation Learning on Graphs and Manifolds; PyTorch, CUDA, Vim and macOS Enthusiast
Migration of Edge-based Distributed Federated Learning

FedFly: Towards Migration in Edge-based Distributed Federated Learning About the research Due to mobility, a device participating in Federated Learnin

qub-blesson 11 Nov 13, 2022
Measures input lag without dedicated hardware, performing motion detection on recorded or live video

What is InputLagTimer? This tool can measure input lag by analyzing a video where both the game controller and the game screen can be seen on a webcam

Bruno Gonzalez 4 Aug 18, 2022
This repo contains code to reproduce all experiments in Equivariant Neural Rendering

Equivariant Neural Rendering This repo contains code to reproduce all experiments in Equivariant Neural Rendering by E. Dupont, M. A. Bautista, A. Col

Apple 83 Nov 16, 2022
realsense d400 -> jpg + csv

Realsense-capture realsense d400 - jpg + csv Requirements RealSense sdk : Installation Python3 pyrealsense2 (RealSense SDK) Numpy OpenCV Tkinter Run

Ar-Ray 2 Mar 22, 2022
Pytorch implementation of U-Net, R2U-Net, Attention U-Net, and Attention R2U-Net.

pytorch Implementation of U-Net, R2U-Net, Attention U-Net, Attention R2U-Net U-Net: Convolutional Networks for Biomedical Image Segmentation https://a

leejunhyun 2k Jan 02, 2023
Versatile Generative Language Model

Versatile Generative Language Model This is the implementation of the paper: Exploring Versatile Generative Language Model Via Parameter-Efficient Tra

Zhaojiang Lin 17 Dec 02, 2022
Implementation of SETR model, Original paper: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers.

SETR - Pytorch Since the original paper (Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers.) has no official

zhaohu xing 112 Dec 16, 2022
Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, arXiv 2021

Hypercorrelation Squeeze for Few-Shot Segmentation This is the implementation of the paper "Hypercorrelation Squeeze for Few-Shot Segmentation" by Juh

Juhong Min 165 Dec 28, 2022
[ICML 2021] DouZero: Mastering DouDizhu with Self-Play Deep Reinforcement Learning | 斗地主AI

[ICML 2021] DouZero: Mastering DouDizhu with Self-Play Deep Reinforcement Learning DouZero is a reinforcement learning framework for DouDizhu (斗地主), t

Kwai Inc. 3.1k Jan 04, 2023
Most popular metrics used to evaluate object detection algorithms.

Rafael Padilla 4.4k Dec 25, 2022
Deep Learning GPU Training System

DIGITS DIGITS (the Deep Learning GPU Training System) is a webapp for training deep learning models. The currently supported frameworks are: Caffe, To

NVIDIA Corporation 4.1k Jan 03, 2023
Open-Domain Question-Answering for COVID-19 and Other Emergent Domains

Open-Domain Question-Answering for COVID-19 and Other Emergent Domains This repository contains the source code for an end-to-end open-domain question

7 Sep 27, 2022
Faster RCNN pytorch windows

Faster-RCNN-pytorch-windows Faster RCNN implementation with pytorch for windows Open cmd, compile this commands: cd lib python setup.py build develop T

Hwa-Rang Kim 1 Nov 11, 2022
Code for 2021 NeurIPS --- Towards Multi-Grained Explainability for Graph Neural Networks

ReFine: Multi-Grained Explainability for GNNs This is the official code for Towards Multi-Grained Explainability for Graph Neural Networks (NeurIPS 20

Shirley (Ying-Xin) Wu 47 Dec 16, 2022
A Python package for generating concise, high-quality summaries of a probability distribution

GoodPoints A Python package for generating concise, high-quality summaries of a probability distribution GoodPoints is a collection of tools for compr

Microsoft 28 Oct 10, 2022
This tutorial aims to learn the basics of deep learning by hands, and master the basics through combination of lectures and exercises

2021-Deep-learning This tutorial aims to learn the basics of deep learning by hands, and master the basics through combination of paper and exercises.

108 Feb 24, 2022
TalkNet 2: Non-Autoregressive Depth-Wise Separable Convolutional Model for Speech Synthesis with Explicit Pitch and Duration Prediction.

TalkNet 2 [WIP] TalkNet 2: Non-Autoregressive Depth-Wise Separable Convolutional Model for Speech Synthesis with Explicit Pitch and Duration Predictio

Rishikesh (ऋषिकेश) 69 Dec 17, 2022
Listing arxiv - Personalized list of today's articles from ArXiv

Personalized list of today's articles from ArXiv Print and/or send to your gmail

Lilianne Nakazono 5 Jun 17, 2022
Contrastive Learning of Structured World Models

Contrastive Learning of Structured World Models This repository contains the official PyTorch implementation of: Contrastive Learning of Structured Wo

Thomas Kipf 371 Jan 06, 2023
LogAvgExp - Pytorch Implementation of LogAvgExp

LogAvgExp - Pytorch Implementation of LogAvgExp for Pytorch Install $ pip instal

Phil Wang 31 Oct 14, 2022