PyTorch Extension Library of Optimized Scatter Operations

Overview

PyTorch Scatter



Documentation

This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in PyTorch, which are missing in the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
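To make the "reduce by group-index" idea concrete, here is a minimal pure-Python sketch of the semantics for the one-dimensional case. It is not the library's implementation (which runs on tensors with optimized CPU/GPU kernels); the function name `scatter_1d` and the choice to fill empty groups with 0 are assumptions made for illustration.

```python
from typing import List, Optional


def scatter_1d(src: List[float], index: List[int], reduce: str = "sum",
               dim_size: Optional[int] = None) -> List[float]:
    """Reference semantics only: reduce all src values that share an index."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0

    # Collect the values of each group; the index does not need to be sorted.
    groups: List[List[float]] = [[] for _ in range(dim_size)]
    for value, group in zip(src, index):
        groups[group].append(value)

    reducers = {
        "sum": sum,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    fn = reducers[reduce]
    # Groups that receive no value are filled with 0 in this sketch.
    return [fn(g) if g else 0 for g in groups]


src = [2.0, 0.0, 1.0, 4.0, 3.0]
index = [4, 5, 4, 2, 3]  # unsorted "group-index"
print(scatter_1d(src, index, "sum"))  # [0, 0, 4.0, 3.0, 3.0, 0.0]
```

A segment operation computes the same result but may assume that equal indices are adjacent, which is what allows a faster implementation.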

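The package consists of the following operations with reduction types "sum"|"mean"|"min"|"max":

  • scatter based on arbitrary indices
  • segment_coo based on sorted indices
  • segment_csr based on compressed indices via pointers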

In addition, we provide the following composite functions which make use of scatter_* operations under the hood: scatter_std, scatter_logsumexp, scatter_softmax and scatter_log_softmax.
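As an illustration of how a composite function can be built from scatter reductions, the following pure-Python sketch computes a softmax independently within each index group using a scatter-max followed by a scatter-sum. The name `scatter_softmax_1d` and the three-step decomposition are assumptions for illustration; the package's own scatter_softmax may be structured differently.

```python
import math
from typing import List


def scatter_softmax_1d(src: List[float], index: List[int]) -> List[float]:
    """Softmax computed independently within each index group (1-D sketch)."""
    num_groups = max(index) + 1 if index else 0

    # Step 1: a scatter-max yields the per-group maximum (numerical stability).
    group_max = [-math.inf] * num_groups
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)

    # Step 2: exponentiate the shifted values and scatter-sum them per group.
    exps = [math.exp(v - group_max[g]) for v, g in zip(src, index)]
    group_sum = [0.0] * num_groups
    for e, group in zip(exps, index):
        group_sum[group] += e

    # Step 3: gather each element's group sum and normalize.
    return [e / group_sum[g] for e, g in zip(exps, index)]


probs = scatter_softmax_1d([1.0, 2.0, 3.0, 4.0], [0, 1, 0, 1])
print(probs)
```

Elements 0 and 2 form one group and elements 1 and 3 another, so each pair of probabilities sums to one.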

All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.

Installation

Binaries

We provide pip wheels for all major OS/PyTorch/CUDA combinations, see here.

PyTorch 1.8.0

To install the binaries for PyTorch 1.8.0, simply run

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu101, cu102, or cu111 depending on your PyTorch installation.
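If you prefer to derive the ${CUDA} tag programmatically, a small helper along these lines can build the URL. This is a sketch: `cuda_tag` and `wheel_index_url` are hypothetical names, and the helper only follows the URL pattern shown above without checking that a wheel exists for the resulting combination.

```python
from typing import Optional


def cuda_tag(cuda_version: Optional[str]) -> str:
    """Map a CUDA version string such as "11.1" to a tag such as "cu111"."""
    if cuda_version is None:  # CPU-only builds report no CUDA version
        return "cpu"
    return "cu" + cuda_version.replace(".", "")


def wheel_index_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the -f URL following the pattern used in the command above."""
    return ("https://pytorch-geometric.com/whl/"
            f"torch-{torch_version}+{cuda_tag(cuda_version)}.html")


print(wheel_index_url("1.8.0", "11.1"))
# https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
```

With PyTorch installed, `torch.version.cuda` provides the CUDA version string (it is `None` for CPU-only builds), so it can be passed as `cuda_version`.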

cpu cu101 cu102 cu111
Linux
Windows
macOS

PyTorch 1.7.0/1.7.1

To install the binaries for PyTorch 1.7.0 and 1.7.1, simply run

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, or cu110 depending on your PyTorch installation.

cpu cu92 cu101 cu102 cu110
Linux
Windows
macOS

Note: Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0 and PyTorch 1.6.0 (following the same procedure).

From source

Ensure that at least PyTorch 1.5.0 is installed and verify that cuda/bin and cuda/include are in your $PATH and $CPATH respectively, e.g.:

$ python -c "import torch; print(torch.__version__)"
>>> 1.5.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...

Then run:

pip install torch-scatter

When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail. In this case, ensure that the compute capabilities are set via TORCH_CUDA_ARCH_LIST, e.g.:

export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"

Example

import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
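To see where these numbers come from, the following pure-Python sketch reproduces the output row by row. It is a reference for the semantics only, not the library's code: `scatter_max_1d` is a hypothetical name, and the fill values (0 for `out`, the length of `src` for `argmax`) are inferred from the output printed above. How ties are resolved in the library is not covered here; this sketch keeps the first maximum.

```python
from typing import List, Tuple


def scatter_max_1d(src: List[int], index: List[int],
                   dim_size: int) -> Tuple[List[int], List[int]]:
    """One-row reference for scatter_max: returns (out, argmax)."""
    out = [0] * dim_size            # slots that receive nothing stay 0
    argmax = [len(src)] * dim_size  # ...and point one past the end of src
    filled = [False] * dim_size
    for pos, (value, slot) in enumerate(zip(src, index)):
        if not filled[slot] or value > out[slot]:
            out[slot], argmax[slot], filled[slot] = value, pos, True
    return out, argmax


rows = [([2, 0, 1, 4, 3], [4, 5, 4, 2, 3]),
        ([0, 2, 1, 3, 4], [0, 0, 2, 2, 1])]
for src_row, index_row in rows:
    # dim_size=6 because the largest index across both rows is 5.
    print(scatter_max_1d(src_row, index_row, dim_size=6))
```

In the first row, for example, output slot 4 receives the values 2 and 1 (source positions 0 and 2), so its maximum is 2 with argmax 0, while slots 0 and 1 receive nothing.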

Running tests

python setup.py test

C++ API

torch-scatter also offers a C++ API that contains C++ equivalents of the Python models.

mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make
make install
Comments
  • Compiling scatter C++ API keeps using old python versions.

    Hi, when I try to compile the torch-scatter C++ API, I run cmake as suggested in the README. However, when cmake tries to find python3, it keeps finding an old version of Python. I have tried numerous approaches (such as adding set() or include_directories, or -D flags on the cmake command line), but it either keeps finding the old python3.8 or fails, saying it cannot find python3.10, even though python3.10 is the environment I configured for my applications. Would you mind providing some examples of how to modify the CMake files, or other suggestions, so that I can force cmake (particularly the find_package() function) to use my python3.10 to compile the scatter C++ API?

    Thank you so much for the help.

    opened by ZKC19940412 2
  •  Not compatible with PyTorch 2.0 nightly builds ( next generation 2-series release of PyTorch )

    Error during compilation of extension

    Reproduce:

    1. Install PyTorch 2.0: python3 -m pip install numpy --pre torch --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
    2. Try to build using: python3 -m pip install torch-scatter
    opened by sxrstudio 1
  • c++ api; scatter_sum works on kCPU, but not kCUDA

    I'm trying to implement a scatter_sum operation via the c++ api.

    I'm calling the function as follows:

    results = scatter_sum(source_nodes, target_index_tensor, dim, torch::nullopt, torch::nullopt);

    I have verified that both tensors are on cuda:0 via these lines:

    std::cout << source_nodes.device() << std::endl;
    std::cout << target_index_tensor.device() << std::endl;
    

    The program simply fails when I use 'kCUDA' as the device, but works when I use 'kCPU'. I have verified that the normal torch functions (linear, relu) work on the kCUDA device, so only this scatter_sum function does not go through. What could be the cause of the program failing? I simply get 'core dumped', but because it works on CPU, it's not clear to me what could be wrong.

    Some information about the system: Python 3.8 CUDA 10.2 PyTorch 1.10

    opened by JellePiepenbrock 6
  • functorch vmap aten::scatter_add_ error

    Overview

    Hi 👋🏼 ,

    I would just like to start by saying, thank you for creating and maintaining this amazing library.

    When attempting to use functorch with pytorch-geometric, I encountered the following error related to scatter_add. Please let me know if I can provide any more information or help out in any way.

    Thank you, Matt

    Code

    from functorch import combine_state_for_ensemble, vmap
    from torch import nn
    from torch_geometric.nn import GCNConv
    from torch_geometric.data import Data
    import torch
    
    NUM_MODELS = 10
    INPUT_SIZE = 8
    NUM_NODES, NUM_EDGES = 4, 8
    
    # create a model
    class Model(nn.Module):
        def __init__(self, input_size: int) -> None:
            super().__init__()
            self.conv1 = GCNConv(input_size, 2, add_self_loops=False).jittable()
    
        
        def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
            return self.conv1(x, edge_index)
    
    # create the data
    xs = torch.randn(NUM_MODELS, NUM_NODES, INPUT_SIZE, dtype=torch.float)
    edge_indices = torch.randint(0, 3, (NUM_MODELS, 2, NUM_EDGES), dtype=torch.long)
    
    # create functional models
    models = [Model(INPUT_SIZE) for _ in range(NUM_MODELS)]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    
    # complete a forward pass with the data
    res = vmap(fmodel)(params, buffers, xs, edge_indices)
    

    Error

    (.venv) [email protected] ~/G/torch-func [0|1]> python3 run.py
    /Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/scatter.py:21: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::scatter_add_. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /Users/runner/work/functorch/functorch/functorch/csrc/BatchedFallback.cpp:85.)
      return out.scatter_add_(dim, index, src)
    Traceback (most recent call last):
      File "/Users/matthewlemay/Github/torch-func/run.py", line 30, in <module>
        res = vmap(fmodel)(params, buffers, xs, edge_indices)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/functorch/_src/vmap.py", line 365, in wrapped
        batched_outputs = func(*batched_inputs, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/functorch/_src/make_functional.py", line 282, in forward
        return self.stateless_model(*args, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/run.py", line 19, in forward
        return self.conv1(x, edge_index)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/var/folders/hh/vh54hqf544n7qf9vxn1lt8_00000gn/T/matthewlemay_pyg/tmp97j0p1uv.py", line 219, in forward
        edge_index, edge_weight = gcn_norm(  # yapf: disable
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 67, in gcn_norm
        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/scatter.py", line 29, in scatter_add
        return scatter_sum(src, index, dim, out, dim_size)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/scatter.py", line 21, in scatter_sum
        return out.scatter_add_(dim, index, src)
    RuntimeError: vmap: aten::scatter_add_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over at level 1. Please try to use out-of-place operators instead of aten::scatter_add_. If said operator is being called inside the PyTorch framework, please file a bug report instead.
    
    opened by mplemay 2
  • `segment_csr` crashes Python when provided invalid `indptr`

    When I run the following code:

    a = torch.arange(10)
    indptr = torch.tensor([0]) # invalid ptr
    segment_csr(a, indptr)
    

    Python crashes on OSX with the following message: [image not preserved]

    I'm on version 2.0.9 of torch_scatter. I think segment_csr should check for bad input like this.

    opened by ArchieGertsman 1
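The kind of input check requested in the segment_csr report above could look roughly like the following pure-Python sketch. It assumes that a valid pointer list has at least two entries, is non-decreasing, and stays within [0, len(src)]; these criteria, and the names `check_indptr` and `segment_sum_csr`, are assumptions for illustration and do not describe what the library actually validates.

```python
from typing import List


def check_indptr(indptr: List[int], num_values: int) -> None:
    """Raise ValueError for pointer lists this sketch treats as invalid."""
    if len(indptr) < 2:
        raise ValueError("indptr needs at least two entries (one segment)")
    if indptr[0] < 0 or indptr[-1] > num_values:
        raise ValueError("indptr must stay within [0, len(src)]")
    if any(a > b for a, b in zip(indptr[:-1], indptr[1:])):
        raise ValueError("indptr must be non-decreasing")


def segment_sum_csr(src: List[int], indptr: List[int]) -> List[int]:
    """Sum src over the segments [indptr[i], indptr[i + 1]) after validation."""
    check_indptr(indptr, len(src))
    return [sum(src[start:end]) for start, end in zip(indptr[:-1], indptr[1:])]


print(segment_sum_csr(list(range(10)), [0, 3, 10]))  # [3, 42]

try:
    segment_sum_csr(list(range(10)), [0])  # the input from the report
except ValueError as err:
    print("rejected:", err)
```

Raising a Python-level error before any kernel runs turns a hard crash into a message the caller can act on.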
Releases(2.0.1)
Owner
Matthias Fey
PhD student @ TU Dortmund University - Interested in Representation Learning on Graphs and Manifolds; PyTorch, CUDA, Vim and macOS Enthusiast