PyTorch Extension Library of Optimized Scatter Operations

Overview

PyTorch Scatter


Documentation

This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in PyTorch, which are missing in the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
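
To make the "group-index" idea concrete, here is a minimal pure-Python sketch of a one-dimensional scatter reduction. It only illustrates the semantics and is not the library's implementation: the name `scatter_reference` and its arguments are hypothetical, and filling empty groups with 0 is an assumption.

```python
def scatter_reference(src, index, reduce="sum", dim_size=None):
    """Reduce the values of `src` that share the same entry in `index` (1-D only)."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0

    # Collect the values belonging to each group.
    groups = [[] for _ in range(dim_size)]
    for value, group in zip(src, index):
        groups[group].append(value)

    reducers = {
        "sum": sum,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    fn = reducers[reduce]
    # Assumption: groups that receive no value are filled with 0.
    return [fn(g) if g else 0 for g in groups]


src = [5, 1, 7, 2, 3]
index = [2, 0, 2, 1, 0]  # unsorted group indices are fine for scatter
print(scatter_reference(src, index, "sum"))  # [4, 2, 12]
print(scatter_reference(src, index, "max"))  # [3, 2, 7]
```

A segment operation would compute the same result, but only after `index` has been sorted (here `[0, 0, 1, 2, 2]` with `src` reordered accordingly).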

The package consists of the following operations with reduction types "sum"|"mean"|"min"|"max":

In addition, we provide the following composite functions which make use of scatter_* operations under the hood: scatter_std, scatter_logsumexp, scatter_softmax and scatter_log_softmax.
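
As an illustration of how a composite function can be layered on top of basic scatter reductions, the following pure-Python sketch computes a per-group softmax from a per-group max followed by a per-group sum. The name `scatter_softmax_reference` is hypothetical, the sketch is 1-D only, and the library's actual implementation may differ.

```python
import math


def scatter_softmax_reference(src, index):
    """Softmax over the entries of `src` that share the same group in `index`."""
    # Step 1: per-group maximum (a scatter "max"), used for numerical stability.
    group_max = {}
    for value, group in zip(src, index):
        group_max[group] = max(value, group_max.get(group, -math.inf))

    # Step 2: exponentiate the shifted values.
    exps = [math.exp(v - group_max[g]) for v, g in zip(src, index)]

    # Step 3: per-group sum of exponentials (a scatter "sum").
    group_sum = {}
    for e, group in zip(exps, index):
        group_sum[group] = group_sum.get(group, 0.0) + e

    # Step 4: normalize each entry by the sum of its own group.
    return [e / group_sum[g] for e, g in zip(exps, index)]


probs = scatter_softmax_reference([1.0, 2.0, 3.0, 0.5], [0, 1, 0, 1])
print(probs)  # entries 0 and 2 sum to 1, as do entries 1 and 3
```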

All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.

Installation

Binaries

We provide pip wheels for all major OS/PyTorch/CUDA combinations, see here.

PyTorch 1.8.0

To install the binaries for PyTorch 1.8.0, simply run

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu101, cu102, or cu111 depending on your PyTorch installation.

cpu cu101 cu102 cu111
Linux
Windows
macOS

PyTorch 1.7.0/1.7.1

To install the binaries for PyTorch 1.7.0 and 1.7.1, simply run

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, or cu110 depending on your PyTorch installation.

cpu cu92 cu101 cu102 cu110
Linux
Windows
macOS

Note: Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0 and PyTorch 1.6.0 (following the same procedure).
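
The install commands above differ only in the PyTorch version and the CUDA tag. As a small convenience sketch, the hypothetical helper below (not part of torch-scatter) assembles the `-f` URL and rejects tags that are not listed in this README for the given version.

```python
# CUDA tags listed in this README for each wheel index.
# PyTorch 1.7.1 uses the same index as 1.7.0.
SUPPORTED_TAGS = {
    "1.8.0": ("cpu", "cu101", "cu102", "cu111"),
    "1.7.0": ("cpu", "cu92", "cu101", "cu102", "cu110"),
}


def wheel_index_url(torch_version, cuda_tag):
    """Return the `-f` URL for pip (hypothetical helper for illustration)."""
    tags = SUPPORTED_TAGS.get(torch_version)
    if tags is None:
        raise ValueError(f"no tag list recorded for PyTorch {torch_version}")
    if cuda_tag not in tags:
        raise ValueError(f"{cuda_tag} is not listed for PyTorch {torch_version}")
    return f"https://pytorch-geometric.com/whl/torch-{torch_version}+{cuda_tag}.html"


print("pip install torch-scatter -f " + wheel_index_url("1.8.0", "cu102"))
```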

From source

Ensure that at least PyTorch 1.5.0 is installed and verify that cuda/bin and cuda/include are in your $PATH and $CPATH respectively, e.g.:

$ python -c "import torch; print(torch.__version__)"
>>> 1.5.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...

Then run:

pip install torch-scatter

When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail. In this case, ensure that the compute capabilities are set via TORCH_CUDA_ARCH_LIST, e.g.:

export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"

Example

import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
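
The printed result can be checked by hand with a pure-Python sketch of the same reduction. The function name `scatter_max_reference` is hypothetical. The two conventions it uses (empty output slots hold 0, and their argmax equals the length of the source row) are inferred from the output shown above rather than taken from the library's source.

```python
def scatter_max_reference(src, index, dim_size):
    """Row-wise scatter max over the last dimension, with argmax positions."""
    out, argmax = [], []
    for src_row, index_row in zip(src, index):
        best = [None] * dim_size
        # Slots that never receive a value keep this sentinel position.
        arg = [len(src_row)] * dim_size
        for pos, (value, group) in enumerate(zip(src_row, index_row)):
            if best[group] is None or value > best[group]:
                best[group], arg[group] = value, pos
        out.append([0 if b is None else b for b in best])
        argmax.append(arg)
    return out, argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]

out, argmax = scatter_max_reference(src, index, dim_size=6)
print(out)     # [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
print(argmax)  # [[5, 5, 3, 4, 0, 1], [1, 4, 3, 5, 5, 5]]
```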

Running tests

python setup.py test

C++ API

torch-scatter also offers a C++ API that contains C++ equivalents of the Python models.

mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make
make install
Comments
  • Compiling scatter C++ API keeps using old python versions.

    Hi, when I try to compile the torch-scatter C++ API, I run cmake as suggested in the README. However, when cmake tries to find python3, it keeps finding an old version of Python. I have tried numerous approaches (such as adding set() or include_directories, or passing -D flags on the cmake command line), but it either keeps finding the old python3.8 or fails by saying it cannot find python3.10, even though python3.10 is the environment I configured for my applications. Would you mind providing some examples of how to modify the CMake files, or other suggestions, so that I can force cmake (particularly the find_package() function) to use my python3.10 when compiling the scatter C++ API?

    Thank you so much for the help.

    opened by ZKC19940412 2
  • Not compatible with PyTorch 2.0 nightly builds (next generation 2-series release of PyTorch)

    Error during compilation of extension

    Reproduce:

    1. Install PyTorch 2.0: python3 -m pip install numpy --pre torch --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
    2. Try to build using: python3 -m pip install torch-scatter
    opened by sxrstudio 1
  • c++ api; scatter_sum works on kCPU, but not kCUDA

    I'm trying to implement a scatter_sum operation via the c++ api.

    I'm calling the function as follows:

    results = scatter_sum(source_nodes, target_index_tensor, dim, torch::nullopt, torch::nullopt);

    I have verified that both tensors are on cuda:0 via these lines:

    std::cout << source_nodes.device() << std::endl;
    std::cout << target_index_tensor.device() << std::endl;
    

    The program simply fails when I use 'kCUDA' as the device, but when I use 'kCPU' as the device, it works. I have verified that the normal torch functions (linear, relu) work on the kCUDA device, so only this scatter_sum function does not go through. What could be the cause of the program failing? I simply get 'core dumped', but because it works on CPU, it's not clear to me what could be wrong.

    Some information about the system: Python 3.8 CUDA 10.2 PyTorch 1.10

    opened by JellePiepenbrock 6
  • functorch vmap aten::scatter_add_ error

    Overview

    Hi 👋🏼 ,

    I would just like to start by saying, thank you for creating and maintaining this amazing library.

    When attempting to use functorch with pytorch-geometric I encountered the following error related to scatter_add. Please let me know if I can provide any more information or help out in any way.

    Thank you, Matt

    Code

    from functorch import combine_state_for_ensemble, vmap
    from torch import nn
    from torch_geometric.nn import GCNConv
    from torch_geometric.data import Data
    import torch
    
    NUM_MODELS = 10
    INPUT_SIZE = 8
    NUM_NODES, NUM_EDGES = 4, 8
    
    # create a model
    class Model(nn.Module):
        def __init__(self, input_size: int) -> None:
            super().__init__()
            self.conv1 = GCNConv(input_size, 2, add_self_loops=False).jittable()
    
        
        def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
            return self.conv1(x, edge_index)
    
    # create the data
    xs = torch.randn(NUM_MODELS, NUM_NODES, INPUT_SIZE, dtype=torch.float)
    edge_indices = torch.randint(0, 3, (NUM_MODELS, 2, NUM_EDGES), dtype=torch.long)
    
    # create functional models
    models = [Model(INPUT_SIZE) for _ in range(NUM_MODELS)]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    
    # complete a forward pass with the data
    res = vmap(fmodel)(params, buffers, xs, edge_indices)
    

    Error

    (.venv) $ python3 run.py
    /Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/scatter.py:21: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::scatter_add_. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /Users/runner/work/functorch/functorch/functorch/csrc/BatchedFallback.cpp:85.)
      return out.scatter_add_(dim, index, src)
    Traceback (most recent call last):
      File "/Users/matthewlemay/Github/torch-func/run.py", line 30, in <module>
        res = vmap(fmodel)(params, buffers, xs, edge_indices)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/functorch/_src/vmap.py", line 365, in wrapped
        batched_outputs = func(*batched_inputs, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/functorch/_src/make_functional.py", line 282, in forward
        return self.stateless_model(*args, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/matthewlemay/Github/torch-func/run.py", line 19, in forward
        return self.conv1(x, edge_index)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/var/folders/hh/vh54hqf544n7qf9vxn1lt8_00000gn/T/matthewlemay_pyg/tmp97j0p1uv.py", line 219, in forward
        edge_index, edge_weight = gcn_norm(  # yapf: disable
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 67, in gcn_norm
        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/scatter.py", line 29, in scatter_add
        return scatter_sum(src, index, dim, out, dim_size)
      File "/Users/matthewlemay/Github/torch-func/.venv/lib/python3.10/site-packages/torch_scatter/scatter.py", line 21, in scatter_sum
        return out.scatter_add_(dim, index, src)
    RuntimeError: vmap: aten::scatter_add_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over at level 1. Please try to use out-of-place operators instead of aten::scatter_add_. If said operator is being called inside the PyTorch framework, please file a bug report instead.
    
    opened by mplemay 2
  • `segment_csr` crashes Python when provided invalid `indptr`

    When I run the following code:

    a = torch.arange(10)
    indptr = torch.tensor([0]) # invalid ptr
    segment_csr(a, indptr)
    

    Python crashes on macOS with the following message: [image]

    I'm on version 2.0.9 of torch_scatter. I think segment_csr should check for bad input like this.

    opened by ArchieGertsman 1
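
The kind of input check requested in the `segment_csr` report above could look roughly like the pure-Python sketch below. Everything here is an assumption for illustration: the names `check_indptr` and `segment_csr_reference` are hypothetical, and the rule that `indptr` needs at least two entries is one possible policy, not the library's documented behavior.

```python
def check_indptr(indptr, src_len):
    """Raise ValueError if `indptr` cannot describe segments of a 1-D source."""
    # Assumed policy: at least one segment, i.e. two pointers, is required.
    if len(indptr) < 2:
        raise ValueError("indptr must contain at least two entries")
    if indptr[0] < 0 or indptr[-1] > src_len:
        raise ValueError("indptr entries must lie in [0, len(src)]")
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be non-decreasing")


def segment_csr_reference(src, indptr, reduce=sum):
    """Reduce src[indptr[i]:indptr[i + 1]] for each consecutive pointer pair."""
    check_indptr(indptr, len(src))
    return [reduce(src[start:end]) for start, end in zip(indptr, indptr[1:])]


a = list(range(10))
print(segment_csr_reference(a, [0, 3, 10]))  # [3, 42]

try:
    segment_csr_reference(a, [0])  # the input from the report
except ValueError as err:
    print("rejected:", err)
```

Validating up front turns the hard crash into an ordinary Python exception, at the cost of one extra pass over `indptr`.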
Releases(2.0.1)
Owner
Matthias Fey
PhD student @ TU Dortmund University - Interested in Representation Learning on Graphs and Manifolds; PyTorch, CUDA, Vim and macOS Enthusiast