TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards

Overview

Documents | Projects | API References

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when the model has massive linear layers (e.g., ViT, BERT, and GPT) or a huge number of classes (millions). It has the same API design as PyTorch.

Installation

pip install torchshard

More options in INSTALL.md.

Usage

import torch
import torchshard as ts

ts.init_process_group(group_size=2)                       # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),    # equivalent to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),       # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),       # parallel in column dimension
).cuda()

x = torch.randn(64, 20).cuda()                            # example input batch (added so the snippet runs)
y = torch.randint(0, 30, (64,)).cuda()                    # example target classes

x = m(x)                                                  # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)      # parallel loss function
loss.backward()                                           # backward

torch.save(
    ts.collect_state_dict(m, m.state_dict()), 'm.pt')     # save model state
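
To load the collected checkpoint back into a parallel model, the full state dict has to be re-sliced into per-rank shards. A minimal sketch, assuming a ts.relocate_state_dict helper as the counterpart of ts.collect_state_dict (check the exact name against the API References above):

state_dict = torch.load('m.pt')                           # full (collected) checkpoint
state_dict = ts.relocate_state_dict(m, state_dict)        # re-slice shards for this rank
m.load_state_dict(state_dict)                             # load model state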

Performance

The following figure shows training ResNet-50 on 8 NVIDIA TITAN-XP (12196 MiB) GPUs while scaling the number of classes from 1,000 → 1 Million. The input size is 224 x 224 and the batch size is 256. Parallelism is 8-way data parallel combined with 8-way model parallel.
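
In the large-class setting, the memory saving comes from sharding the final classifier: a column-parallel layer stores only num_classes / group_size output columns per rank, and the sharded logits feed parallel_cross_entropy as in the usage example above. A minimal sketch using only the layers shown above, not the project code; the feature dimension, batch size, and class count are illustrative:

import torch
import torchshard as ts

ts.init_process_group(group_size=8)                            # 8-way model parallel

NUM_CLASSES = 1_000_000                                        # each rank stores only a slice of the weight
head = ts.nn.ParallelLinear(2048, NUM_CLASSES, bias=True, dim=1).cuda()   # column-parallel classifier

feats = torch.randn(256, 2048).cuda()                          # backbone features for one batch
labels = torch.randint(0, NUM_CLASSES, (256,)).cuda()

logits = head(feats)                                           # forward through the sharded classifier
loss = ts.nn.functional.parallel_cross_entropy(logits, labels) # loss over sharded logits
loss.backward()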

The following figure shows training minGPT on 8 NVIDIA TITAN-XP (12196 MiB) GPUs while scaling the number of parameters from 10 Million → 808 Million. The input size is 32 x 32 and the batch size is 16. Parallelism is 1-way data parallel combined with 8-way model parallel.
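
In the parameter-scaling setting, the layers being sharded are the large projections inside the transformer blocks. Below is a minimal, Megatron-style sketch of a column-parallel then row-parallel MLP built only from the layers shown above; it is not the minGPT project code, the width is illustrative, and whether the intermediate activation needs extra handling depends on ParallelLinear's internals:

import torch
import torchshard as ts

ts.init_process_group(group_size=8)                        # 8-way model parallel

embed_dim = 1024                                           # illustrative model width
mlp = torch.nn.Sequential(
    ts.nn.ParallelLinear(embed_dim, 4 * embed_dim, bias=True, dim=1),   # column-parallel expansion
    torch.nn.GELU(),
    ts.nn.ParallelLinear(4 * embed_dim, embed_dim, bias=True, dim=0),   # row-parallel projection
).cuda()

h = torch.randn(16, embed_dim).cuda()                      # a batch of token features
out = mlp(h)                                               # each rank holds roughly 1/8 of the MLP weights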

Contributing

TorchShard welcomes your expertise and enthusiasm!

If you are interested in TorchShard, you are welcome to help:

  • polish code and develop new features
  • develop high-quality tutorials, projects, and advanced materials

Direct pull requests are welcome. Contact: kaiyuyue [at] umd.edu.

Citing TorchShard

If you find TorchShard helpful in your research and would like to cite it, please use the following BibTeX entry.

@misc{torchshard2021,
  author =       {Kaiyu Yue},
  title =        {TorchShard},
  howpublished = {\url{https://github.com/KaiyuYue/torchshard}},
  year =         {2021}
}
Comments
  • Future Planning on this project.

    Hello Kaiyu, I love this awesome project. The API design is elegant and simple, and the software is lightweight and user-friendly. My understanding is that this project implements a series of PyTorch wrappers for tensor slicing.

    1. I am curious about the future planning of this project.
    2. Is there some overlap in functionality between torchshard and the N-D parallelism proposed in ColossalAI?
    3. How is compatibility with ZeRO? According to the amp+zero example, the memory footprint changes only slightly after combining torchshard with ZeRO.
    opened by feifeibear 2
  • Which one is faster?

    Thanks for contributing this great lib. I have one question: which one is faster between dim=0 and dim=1? The documentation seems to contain only accuracy results.

    opened by NOBLES5E 2
  • 8-GPU test example raises an error.

    When I run the unit tests with two GPU devices, they pass with the command below: CUDA_VISIBLE_DEVICES=0,1 python3 -m unittest discover -v -s tests

    But when I run the unit tests with eight GPU devices, they raise ncclSystemError. Run command: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m unittest discover -v -s tests raises the error: RuntimeError: NCCL error in ../torch/lib/c10d/ProcessGroupNCCL.cpp:825, unhandled system error, NCCL version 2.7.8 ncclSystemError: System call (socket, malloc, munmap, etc) failed.

    Is it necessary for the unit tests to pass on eight GPU devices?

    opened by JiaquanYe 1
  • Error?

    Hi, thanks for the excellent work! When I install it from pip and run:

    import torchshard as ts
    ts.init_process_group(group_size=2) 
    

    the following AttributeError occurs:

    AttributeError: module 'torchshard' has no attribute 'init_process_group'
    
    opened by WangWenhao0716 1
  • Multi-node setting?

    https://github.com/KaiyuYue/torchshard/blob/89e21def180bf6063ceb2e312a61631173abc7e7/projects/minGPT/main.py#L150

    I have noticed that group_size is set to world_size in the examples, but in fact group_size can be set to other numbers, according to my understanding.

    https://github.com/KaiyuYue/torchshard/blob/main/torchshard/distributed/core.py#L18

    I have also found that get_world_size() returns the total number of processes.

    These two findings confuse me in a multi-node setting, say 2 nodes, each with 2 processes.

    If group_size is 2, then there are 2 distinct groups besides the default group (with overlap). However, using get_world_size() without specifying a group can make a layer be split into 4 parts, whereas 2 parts are expected in our case (see the generic sketch after these comments).

    Correct me if I am wrong.

    Good Issue 
    opened by GeneZC 1
  • Is it possible to collect the state dict on CPU?

    When I finish one epoch of training, the main_worker function calls ts.collect_state_dict(model, state_dict). But because of limited GPU resources, my machine raises Out of Memory when ts.collect_state_dict(model, state_dict) is called. I found that it gathers the state_dict on GPU; is there any way to gather it on CPU?

    Good Issue 
    opened by JiaquanYe 2
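
Regarding the multi-node comment above, the distinction between the default group and subgroups can be illustrated with plain torch.distributed. This is a generic sketch of the 2-node x 2-process layout described there, not TorchShard code:

import torch.distributed as dist

# Launched with 4 processes in total (2 nodes x 2 processes), e.g. via torchrun.
dist.init_process_group(backend='nccl')

dist.get_world_size()                  # 4: size of the default group (all processes)

# Two disjoint subgroups of size 2; every rank must call new_group for both.
g0 = dist.new_group(ranks=[0, 1])
g1 = dist.new_group(ranks=[2, 3])
my_group = g0 if dist.get_rank() < 2 else g1

dist.get_world_size(group=my_group)    # 2: only this value matches group_size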
Releases: v0.1

Owner: Kaiyu Yue