PyTorch Extension Library of Optimized Scatter Operations


PyTorch Scatter

This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in PyTorch, which are missing from the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations have no such requirement.
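To make the distinction concrete, here is a minimal pure-Python sketch of the semantics for 1-D inputs. The helper names `scatter_reduce` and `segment_reduce` are hypothetical and are not part of the package; the sketch only illustrates what a reduction over a "group-index" computes, not how the optimized kernels work. Filling empty groups with 0 is an assumption of the sketch.

```python
from itertools import groupby

REDUCERS = {
    "sum": sum,
    "mean": lambda xs: sum(xs) / len(xs),
    "min": min,
    "max": max,
}


def scatter_reduce(src, index, reduce="sum"):
    """Scatter: `index` may be in any order, so values are bucketed first."""
    buckets = {}
    for value, group in zip(src, index):
        buckets.setdefault(group, []).append(value)
    size = max(index) + 1
    # Assumption of this sketch: groups that receive no value are set to 0.
    return [REDUCERS[reduce](buckets[g]) if g in buckets else 0 for g in range(size)]


def segment_reduce(src, index, reduce="sum"):
    """Segment: `index` must be sorted, so each group is one contiguous run."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("segment operations need a sorted group-index")
    out = [0] * (index[-1] + 1)
    for group, run in groupby(zip(index, src), key=lambda pair: pair[0]):
        out[group] = REDUCERS[reduce]([value for _, value in run])
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0]
unsorted_index = [2, 0, 2, 1, 0]
sorted_index = [0, 0, 1, 2, 2]

print(scatter_reduce(src, unsorted_index, "sum"))  # [7.0, 4.0, 4.0]
print(segment_reduce(src, sorted_index, "max"))    # [2.0, 3.0, 5.0]
```

Because a sorted index guarantees contiguous groups, the segment variant can reduce each run in a single pass without bucketing, which is the property the segment operations rely on.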

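The package consists of the following operations with reduction types "sum"|"mean"|"min"|"max":

  • scatter, based on arbitrary indices
  • segment_coo, based on sorted indices
  • segment_csr, based on compressed indices via pointers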

In addition, we provide the following composite functions which make use of scatter_* operations under the hood: scatter_std, scatter_logsumexp, scatter_softmax and scatter_log_softmax.
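As an illustration of how such a composite can be assembled from basic scatter reductions, the following pure-Python sketch computes a per-group softmax for 1-D inputs with a scatter-max step followed by a scatter-sum step. It is a sketch of the general technique under that assumption, not the package's actual implementation, and it defines its own `scatter_softmax` helper rather than calling the library function of the same name.

```python
import math


def scatter_softmax(src, index):
    """Softmax computed separately within each group of `index` (1-D sketch)."""
    size = max(index) + 1
    # Step 1: a scatter-max yields the per-group maximum (numerical stability).
    group_max = [-math.inf] * size
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)
    # Step 2: exponentiate the shifted values and scatter-sum them per group.
    exps = [math.exp(value - group_max[group]) for value, group in zip(src, index)]
    group_sum = [0.0] * size
    for e, group in zip(exps, index):
        group_sum[group] += e
    # Step 3: gather each element's group sum back and normalize.
    return [e / group_sum[group] for e, group in zip(exps, index)]


probs = scatter_softmax([1.0, 1.0, 3.0, 3.0, 0.5], [0, 0, 1, 1, 2])
print(probs)  # [0.5, 0.5, 0.5, 0.5, 1.0]
```

Each group's outputs sum to 1, independently of the other groups.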

All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.



Installation

We provide pip wheels for all major OS/PyTorch/CUDA combinations, see here.

PyTorch 1.8.0

To install the binaries for PyTorch 1.8.0, simply run

pip install torch-scatter -f${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu101, cu102, or cu111 depending on your PyTorch installation.
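If you are unsure which value to use for ${CUDA}, a small helper along these lines can derive it from the CUDA version your PyTorch build reports. The function name `cuda_tag` is hypothetical; with PyTorch installed, its input would typically be `torch.version.cuda`, which is `None` for CPU-only builds. Whether a wheel exists for the resulting tag still has to be checked against the list above.

```python
from typing import Optional


def cuda_tag(cuda_version: Optional[str]) -> str:
    """Map a CUDA version string such as "10.2" to a wheel tag such as "cu102"."""
    if cuda_version is None:  # CPU-only PyTorch builds report no CUDA version
        return "cpu"
    major, minor = cuda_version.split(".")[:2]
    return f"cu{major}{minor}"


# With PyTorch installed, the argument would usually be `torch.version.cuda`.
print(cuda_tag("10.2"))  # cu102
print(cuda_tag("11.1"))  # cu111
print(cuda_tag(None))    # cpu
```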


PyTorch 1.7.0/1.7.1

To install the binaries for PyTorch 1.7.0 and 1.7.1, simply run

pip install torch-scatter -f${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, or cu110 depending on your PyTorch installation.


Note: Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0 and PyTorch 1.6.0 (following the same procedure).

From source

Ensure that at least PyTorch 1.5.0 is installed and verify that cuda/bin and cuda/include are in your $PATH and $CPATH respectively, e.g.:

$ python -c "import torch; print(torch.__version__)"
>>> 1.5.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...

Then run:

pip install torch-scatter

When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail. In this case, ensure that the compute capabilities are set via TORCH_CUDA_ARCH_LIST, e.g.:

export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"


Example

import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)

print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
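
To see where these numbers come from, the following pure-Python sketch reproduces the example row by row. The helper `scatter_max_1d` is hypothetical, and the fill conventions are inferred from the output above rather than taken from the implementation: slots that receive no value report 0 as their maximum and len(src) as their argmax.

```python
def scatter_max_1d(src, index, dim_size=None):
    """Return (out, argmax) for one row, mirroring the example's conventions."""
    if dim_size is None:
        dim_size = max(index) + 1
    out = [None] * dim_size
    # Assumption inferred from the example output: empty slots get
    # len(src) as their argmax and 0 as their maximum.
    argmax = [len(src)] * dim_size
    for pos, (value, slot) in enumerate(zip(src, index)):
        if out[slot] is None or value > out[slot]:
            out[slot], argmax[slot] = value, pos
    return [0 if v is None else v for v in out], argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
dim_size = max(max(row) for row in index) + 1  # shared output width: 6

for src_row, index_row in zip(src, index):
    print(scatter_max_1d(src_row, index_row, dim_size))
# ([0, 0, 4, 3, 2, 0], [5, 5, 3, 4, 0, 1])
# ([2, 4, 3, 0, 0, 0], [1, 4, 3, 5, 5, 5])
```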

Running tests

python test


C++ API

torch-scatter also offers a C++ API that contains C++ equivalents of the Python models.

mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make install
  • Compiling scatter C++ API keeps using old python versions.


    Hi, when I try to compile the torch-scatter C++ API, I run cmake as suggested in the README. However, when cmake looks for python3, it keeps finding an old version of Python. I have tried numerous approaches (such as adding set() or include_directories, or passing -D flags on the cmake command line), but it either keeps finding the old python3.8 or fails, saying it cannot find python3.10, even though python3.10 is the environment I configured for my applications. Would you mind providing some examples of how to modify the CMake files, or other suggestions, so that I can force cmake (particularly the find_package() function) to use my python3.10 when compiling the scatter C++ API?

    Thank you so much for the help.

    opened by ZKC19940412 2
  • Not compatible with PyTorch 2.0 nightly builds (next-generation 2-series release of PyTorch)


    Error during compilation of extension


    1. Install PyTorch 2.0: python3 -m pip install numpy --pre torch --force-reinstall --extra-index-url
    2. Try to build with: python3 -m pip install torch-scatter
    opened by sxrstudio 1
  • c++ api; scatter_sum works on kCPU, but not kCUDA


    I'm trying to implement a scatter_sum operation via the c++ api.

    I'm calling the function as follows:

    results = scatter_sum(source_nodes, target_index_tensor, dim, torch::nullopt, torch::nullopt);

    I have verified that both tensors are on cuda:0 via these lines:

    std::cout << source_nodes.device() << std::endl;
    std::cout << target_index_tensor.device() << std::endl;

    The program simply fails when I use 'kCUDA' as the device, but it works when I use 'kCPU'. I have verified that the normal torch functions (linear, relu) work on the kCUDA device, so only this scatter_sum call does not go through. What could be the cause of the failure? I simply get 'core dumped', and because it works on CPU, it is not clear to me what could be wrong.

    Some information about the system: Python 3.8, CUDA 10.2, PyTorch 1.10

    opened by JellePiepenbrock 6
  • functorch vmap aten::scatter_add_ error



    Hi 👋🏼 ,

    I would just like to start by saying, thank you for creating and maintaining this amazing library.

    When attempting to use functorch with pytorch-geometric, I encountered the following error related to scatter_add. Please let me know if I can provide any more information or help out in any way.

    Thank you, Matt


    from functorch import combine_state_for_ensemble, vmap
    from torch import nn
    from torch_geometric.nn import GCNConv
    from torch_geometric.data import Data
    import torch
    NUM_MODELS = 10
    INPUT_SIZE = 8
    NUM_NODES = 3  # assumed value, not shown in the report (node ids below are drawn from 0..2)
    NUM_EDGES = 4  # assumed value, not shown in the report
    # create a model
    class Model(nn.Module):
        def __init__(self, input_size: int) -> None:
            super().__init__()  # must run before submodules are assigned
            self.conv1 = GCNConv(input_size, 2, add_self_loops=False).jittable()
        def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
            return self.conv1(x, edge_index)
    # create the data
    xs = torch.randn(NUM_MODELS, NUM_NODES, INPUT_SIZE, dtype=torch.float)
    edge_indices = torch.randint(0, 3, (NUM_MODELS, 2, NUM_EDGES), dtype=torch.long)
    # create functional models
    models = [Model(INPUT_SIZE) for _ in range(NUM_MODELS)]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    # complete a forward pass with the data (this is the call that raises the error below)
    res = vmap(fmodel)(params, buffers, xs, edge_indices)


    (.venv) ~/G/torch-func [0|1]> python3
    /Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/ UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::scatter_add_. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /Users/runner/work/functorch/functorch/functorch/csrc/BatchedFallback.cpp:85.)
      return out.scatter_add_(dim, index, src)
    Traceback (most recent call last):
      File "/Users/matthewlemay/Github/torch-func/", line 30, in <module>
        res = vmap(fmodel)(params, buffers, xs, edge_indices)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/functorch/_src/", line 365, in wrapped
        batched_outputs = func(*batched_inputs, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/functorch/_src/", line 282, in forward
        return self.stateless_model(*args, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/", line 19, in forward
        return self.conv1(x, edge_index)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/var/folders/hh/vh54hqf544n7qf9vxn1lt8_00000gn/T/matthewlemay_pyg/", line 219, in forward
        edge_index, edge_weight = gcn_norm(  # yapf: disable
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_geometric/nn/conv/", line 67, in gcn_norm
        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/", line 29, in scatter_add
        return scatter_sum(src, index, dim, out, dim_size)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/", line 21, in scatter_sum
        return out.scatter_add_(dim, index, src)
    RuntimeError: vmap: aten::scatter_add_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over at level 1. Please try to use out-of-place operators instead of aten::scatter_add_. If said operator is being called inside the PyTorch framework, please file a bug report instead.
    opened by mplemay 2
  • `segment_csr` crashes Python when provided invalid `indptr`


    When I run the following code:

    import torch
    from torch_scatter import segment_csr

    a = torch.arange(10)
    indptr = torch.tensor([0])  # invalid ptr
    segment_csr(a, indptr)

    Python crashes on OSX (the error message was attached as an image).

    I'm on version 2.0.9 of torch_scatter. I think segment_csr should check for bad input like this.

    opened by ArchieGertsman 1
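
The kind of input check requested in this last report could look roughly like the pure-Python sketch below. The helpers `check_indptr` and `segment_sum_csr` are hypothetical, and the rules they enforce are assumptions based on the conventional CSR layout (at least two pointers, non-decreasing, within the bounds of the input); they do not describe what the library itself accepts or rejects.

```python
def check_indptr(indptr, length):
    """Raise ValueError unless `indptr` describes segments of a `length`-long input."""
    if len(indptr) < 2:
        raise ValueError("indptr needs at least two entries (one segment)")
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be non-decreasing")
    if indptr[0] < 0 or indptr[-1] > length:
        raise ValueError("indptr must stay within [0, len(src)]")


def segment_sum_csr(src, indptr):
    """Sum src[indptr[i]:indptr[i + 1]] for every consecutive pointer pair."""
    check_indptr(indptr, len(src))
    return [sum(src[lo:hi]) for lo, hi in zip(indptr, indptr[1:])]


a = list(range(10))
print(segment_sum_csr(a, [0, 3, 10]))  # [3, 42]

try:
    segment_sum_csr(a, [0])  # the pointer list from the report
except ValueError as err:
    print("rejected:", err)
```

Validating up front turns a hard crash into an ordinary exception that the caller can handle.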
Matthias Fey
PhD student @ TU Dortmund University - Interested in Representation Learning on Graphs and Manifolds; PyTorch, CUDA, Vim and macOS Enthusiast