Deep and online learning with spiking neural networks in Python

Overview

Introduction

https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_scaled.png?raw=true

The brain is the perfect place to look for inspiration to develop more efficient neural networks. One of the main differences from modern deep learning is that the brain encodes information in spikes rather than continuous activations. snnTorch is a Python package for performing gradient-based learning with spiking neural networks. It extends the capabilities of PyTorch, taking advantage of its GPU-accelerated tensor computation and applying it to networks of spiking neurons. Pre-designed spiking neuron models are seamlessly integrated within the PyTorch framework and can be treated as recurrent activation units.

https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/spike_excite_alpha_ps2.gif?raw=true

snnTorch Structure

snnTorch contains the following components:

Component              Description
snntorch               a spiking neuron library like torch.nn, deeply integrated with autograd
snntorch.backprop      variations of backpropagation commonly used with SNNs
snntorch.functional    common arithmetic operations on spikes, e.g., loss, regularization, etc.
snntorch.spikegen      a library for spike generation and data conversion
snntorch.spikeplot     visualization tools for spike-based data using matplotlib and celluloid
snntorch.spikevision   contains popular neuromorphic datasets
snntorch.surrogate     optional surrogate gradient functions
snntorch.utils         dataset utility functions
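As an illustration of what "spike generation and data conversion" means, the plain-Python sketch below implements Bernoulli rate coding. Each input intensity p in [0, 1] becomes a spike train that fires at every time step with probability p. It is only a sketch of the idea. snnTorch ships its own tensor-based encoders in snntorch.spikegen (e.g., spikegen.rate), and the function name rate_encode here is hypothetical.

```python
import random


def rate_encode(intensities, num_steps, seed=None):
    """Bernoulli rate coding (hypothetical helper, not snntorch.spikegen).

    Returns a time-first nested list of shape [num_steps][num_features]; at each
    step, feature i spikes (1) with probability intensities[i].
    """
    if any(not 0.0 <= p <= 1.0 for p in intensities):
        raise ValueError("intensities must lie in [0, 1]")
    rng = random.Random(seed)  # seeded generator for reproducible spike trains
    return [[1 if rng.random() < p else 0 for p in intensities] for _ in range(num_steps)]


spikes = rate_encode([0.0, 0.5, 1.0], num_steps=8, seed=0)
for row in spikes:
    print(row)  # first column never fires, last column fires every step
```

Over many steps, the firing rate of each feature approaches its intensity. This is why rate coding trades latency (more time steps) for precision.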

snnTorch is designed to be intuitively used with PyTorch, as though each spiking neuron were simply another activation in a sequence of layers. It is therefore agnostic to fully-connected layers, convolutional layers, residual connections, etc.

At present, the neuron models are represented by recursive functions, which remove the need to store membrane potential traces for all neurons in a system in order to calculate the gradient. The lean requirements of snnTorch enable small and large networks to be viably trained on CPU, where needed. Provided that the network models and tensors are loaded onto CUDA, snnTorch takes advantage of GPU acceleration in the same way as PyTorch.
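To make "recursive function" concrete, here is a simplified plain-Python sketch of a leaky integrate-and-fire update with reset by subtraction, U[t+1] = βU[t] + I[t+1] − S[t]·U_thr. Only the latest spike and membrane potential are carried from one step to the next. No trace history is kept during the forward pass. The names lif_step and run_lif are hypothetical, and this is not snnTorch's snn.Leaky implementation.

```python
def lif_step(x, mem, spk_prev, beta=0.5, threshold=1.0):
    """One recursive LIF update (simplified sketch).

    U[t+1] = beta * U[t] + I[t+1] - S[t] * threshold, then S[t+1] = 1 if U[t+1] > threshold.
    """
    mem = beta * mem + x - spk_prev * threshold
    spk = 1 if mem > threshold else 0
    return spk, mem


def run_lif(inputs, beta=0.5, threshold=1.0):
    spk, mem = 0, 0.0  # the only state carried between time steps
    spikes = []
    for x in inputs:
        spk, mem = lif_step(x, mem, spk, beta, threshold)
        spikes.append(spk)
    return spikes


print(run_lif([0.6, 0.6, 0.6, 0.0]))  # [0, 0, 1, 0]
```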

Citation

If you find snnTorch useful in your work, please consider citing the following source:

Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. “Training Spiking Neural Networks Using Lessons From Deep Learning”. arXiv preprint arXiv:2109.12894, September 2021.

@article{eshraghian2021training,
title={Training spiking neural networks using lessons from deep learning},
author={Eshraghian, Jason K and Ward, Max and Neftci, Emre and Wang, Xinxin
and Lenz, Gregor and Dwivedi, Girish and Bennamoun, Mohammed and Jeong, Doo Seok
and Lu, Wei D},
journal={arXiv preprint arXiv:2109.12894},
year={2021}
}

Requirements

The following packages need to be installed to use snnTorch:

  • torch >= 1.1.0
  • numpy >= 1.17
  • pandas
  • matplotlib
  • math (part of the Python standard library)

They are automatically installed if snnTorch is installed using the pip command. Ensure the correct version of torch is installed for your system to enable CUDA compatibility.
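A quick way to see which of these packages are present, and at which version, is the standard-library importlib.metadata module (Python 3.8+). The sketch below is an assumption-level helper, not part of snnTorch. installed_versions is a hypothetical name, and math is left out because it has no separate distribution. Once torch is installed, torch.cuda.is_available() reports whether that build can use a GPU.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(["torch", "numpy", "pandas", "matplotlib"]).items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```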

Installation

Run the following to install:

$ pip install snntorch

To install snnTorch from source instead:

$ git clone https://github.com/jeshraghian/snnTorch
$ cd snnTorch
$ python setup.py install

API & Examples

A complete API is available here. Examples, tutorials and Colab notebooks are provided.

Quickstart

Here are a few ways you can get started with snnTorch:

Open In Colab

For a quick example to run snnTorch, see the following snippet, or test the quickstart notebook above:

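import torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import utils

num_steps = 25  # number of time steps
batch_size = 1
beta = 0.5  # neuron decay rate
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient used in the backward pass

net = nn.Sequential(
      nn.Conv2d(1, 8, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Conv2d(8, 16, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Flatten(),
      nn.Linear(16 * 4 * 4, 10),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad, output=True)
      )

# random input data: [num_steps, batch_size, channels, height, width]
data_in = torch.rand(num_steps, batch_size, 1, 28, 28)

spike_recording = []

utils.reset(net)  # clear the hidden states stored by init_hidden=True

for step in range(num_steps):
    # the output layer (output=True) returns its spikes and membrane potential
    spike, state = net(data_in[step])
    spike_recording.append(spike)

# stack into a [num_steps, batch_size, 10] tensor of output spikes
spike_recording = torch.stack(spike_recording)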

If you're feeling lazy and want the training process to be taken care of:

import snntorch.functional as SF
from snntorch import backprop

# correct class should fire 80% of the time
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))

# train for one epoch using the backprop through time algorithm
# assume train_loader is a DataLoader with time-varying input
avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,
                        optimizer=optimizer, criterion=loss_fn)

A Deep Dive into SNNs

If you wish to learn all the fundamentals of training spiking neural networks, from neuron models, to the neural code, up to backpropagation, the snnTorch tutorial series is a great place to begin. It consists of interactive notebooks with complete explanations that can get you up to speed.

Tutorial     Title                                                Colab Link
Tutorial 1   Spike Encoding with snnTorch                         Open In Colab
Tutorial 2   The Leaky Integrate and Fire Neuron                  Open In Colab
Tutorial 3   A Feedforward Spiking Neural Network                 Open In Colab
Tutorial 4   2nd Order Spiking Neuron Models (Optional)           Open In Colab
Tutorial 5   Training Spiking Neural Networks with snnTorch       Open In Colab
Tutorial 6   Surrogate Gradient Descent in a Convolutional SNN    Open In Colab
Tutorial 7   Neuromorphic Datasets with Tonic + snnTorch          Open In Colab

Contributing

If you're ready to contribute to snnTorch, instructions to do so can be found here.

Acknowledgments

snnTorch was initially developed by Jason K. Eshraghian in the Lu Group (University of Michigan).

Additional contributions were made by Xinxin Wang, Vincent Sun, and Emre Neftci.

Several features in snnTorch were inspired by the work of Friedemann Zenke, Emre Neftci, Doo Seok Jeong, Sumit Bam Shrestha and Garrick Orchard.

License & Copyright

snnTorch is licensed under the GNU General Public License v3.0: https://www.gnu.org/licenses/gpl-3.0.en.html.

Comments
  • Examples of regression?

    I was wondering if anyone had used snnTorch for regression, and perhaps how you set your networks up. Just looking for simple, general examples! MSELoss would likely be the type of loss used as I see it.
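One common way to frame regression (an assumption here, not an answer given in the thread) is to read out the membrane potential of the output neurons instead of their spikes, and to apply an MSE loss to it. In snnTorch, neurons created with output=True return the membrane potential alongside the spikes, so torch.nn.MSELoss could be applied to that tensor. The plain-Python sketch below shows the idea with a single leaky, non-spiking readout. membrane_readout and mse are hypothetical helper names.

```python
def membrane_readout(inputs, beta=0.9):
    """Leaky, non-spiking output neuron: U[t] = beta * U[t-1] + I[t]; returns U at every step."""
    mem, trace = 0.0, []
    for x in inputs:
        mem = beta * mem + x
        trace.append(mem)
    return trace


def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


pred = membrane_readout([1.0, 1.0, 1.0], beta=0.5)  # [1.0, 1.5, 1.75]
print(mse(pred, [1.0, 1.5, 2.0]))                   # (1.75 - 2.0) ** 2 / 3
```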

    opened by shilpakancharla
  • snntorch-ipu crashing

    • snntorch version: 0.5.3
    • snntorch-ipu version: 0.5.18
    • PopTorch version: 2.6.0
    • PyTorch version: 1.10.0
    • Python version: 3.8.10
    • Operating System: Ubuntu 20.04

    Description

    I've been trying to train a model in an IPU environment using PopTorch and snntorch-ipu. Unfortunately, I always get a crash. It is unclear to me what exactly is going on, so hopefully someone knows.

    What I Did

    If I try to train my model with only snntorch-ipu installed, as recommended, I will always get an error message when importing/working with surrogates about "Missing Straight Through Estimator Custom Operation file".

    /notebooks/dvsclf/network/net.py in <module>
          1 import torch
    ----> 2 from snntorch import surrogate
          3 from snntorch import utils
          4 import torch.nn as nn
          5 import numpy as np
    
    /usr/local/lib/python3.8/dist-packages/snntorch/surrogate.py in <module>
         26 
         27 
    ---> 28 class StraightThroughEstimator:
         29     """
         30     Straight Through Estimator.
    
    /usr/local/lib/python3.8/dist-packages/snntorch/surrogate.py in StraightThroughEstimator()
         53         print("Missing Straight Through Estimator Custom Operation file!")
         54         print(so_path_ste)
    ---> 55         exit(1)
         56     ctypes.cdll.LoadLibrary(so_path_ste)
         57 
    
    NameError: name 'exit' is not defined
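The NameError above is raised from inside the missing-file branch itself. exit is a convenience name that Python's site module adds to builtins for interactive use, so it is not guaranteed to exist (for example, when the interpreter starts with -S). Library code conventionally calls sys.exit or, better, raises an exception that the caller can handle. The sketch below shows the exception-raising pattern. load_custom_op is a hypothetical name, not snnTorch's actual surrogate.py code.

```python
import ctypes
import os


def load_custom_op(so_path):
    """Hypothetical loader: fail with a catchable exception instead of calling exit()."""
    if not os.path.exists(so_path):
        raise FileNotFoundError(
            f"Missing Straight Through Estimator custom operation file: {so_path}"
        )
    return ctypes.cdll.LoadLibrary(so_path)
```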
    

    If I install snntorch (with or without snntorch-ipu beside it), I will not get the above error. Instead, something in PopTorch throws an error when the model is being trained.

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-4-9ee252bfabe9> in <module>
          2     # Performs forward pass, loss function evaluation,
          3     # backward pass and weight update in one go on the device.
    ----> 4     _, loss = poptorch_model(batch, target)
    
    [....]
    
    /notebooks/dvsclf/network/snn.py in forward(self, x)
         20 
         21         x = self.conv(x)
    ---> 22         x = self.lif(x)
         23         return x
         24 
    
    /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1118             input = bw_hook.setup_input_hook(input)
       1119 
    -> 1120         result = forward_call(*input, **kwargs)
       1121         if _global_forward_hooks or self._forward_hooks:
       1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
    
    /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
       1088                 recording_scopes = False
       1089         try:
    -> 1090             result = self.forward(*input, **kwargs)
       1091         finally:
       1092             if recording_scopes:
    
    /usr/local/lib/python3.8/dist-packages/snntorch/_neurons/leaky.py in forward(self, input_, mem)
        159         if self.init_hidden:
        160             self._leaky_forward_cases(mem)
    --> 161             self.reset = self.mem_reset(self.mem)
        162             self.mem = self.state_fn(input_)
        163 
    
    /usr/local/lib/python3.8/dist-packages/snntorch/_neurons/neurons.py in mem_reset(self, mem)
         86         """Generates detached reset signal if mem > threshold.
         87         Returns reset."""
    ---> 88         mem_shift = mem - self.threshold
         89         reset = self.spike_grad(mem_shift).clone().detach()
         90 
    
    /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in __torch_function__(cls, func, types, args, kwargs)
        279                     if kwargs is None:
        280                         kwargs = {}
    --> 281                     return super().__torch_function__(func, types, args,
        282                                                       kwargs)
        283 
    
    RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
    Tensor:
    (1,1,.,.) = 
     Columns 1 to 9  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529 ....
    

    With self.lif = snn.Leaky(beta=0.95, spike_grad=surrogate.fast_sigmoid(), init_hidden=True, learn_beta=True, learn_threshold=True)

    opened by RoelMK
  • Networks Learns Nothing!

    • snntorch version: 0.2.8
    • Python version: 3.7.4
    • Operating System: Windows 10

    Description

    Hi,

    I tried to use snnTorch to do exactly what you do in Tutorial 3 (without spike_grad), and also in the upcoming Tutorial 4 (applying spike_grad to the Stein neuron) for spiking CNNs. I also converted my dataset from static to spiking form using rate encoding, as in Tutorial 4.

    However, I usually find that the network learns nothing: the training loss goes down and up slightly, but the outputs and accuracy are completely unintuitive.

    For example, a conventional CNN on my dataset reached 85% accuracy, but with the same architecture and hyperparameters in the spiking version, nothing is learned.

    I would be glad to hear any direct suggestions about what the problem may be and why the network learns nothing in the spiking domain.

    Looking forward to your response

    opened by Dola47
  • Minor Update to Dependencies list in README.rst

    • snntorch version: 0.5.3
    • Python version: 3.10.4
    • Operating System: Windows 10

    Description

    I think that ffmpeg is missing from the dependencies listed in README.rst.

    What I Did

    Using conda, I installed the dependencies and then received the following error on the video conversion step of the first tutorial:

    RuntimeError: Requested MovieWriter (ffmpeg) not available
    

    This was easily sorted with: conda install -c conda-forge ffmpeg.
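Matplotlib's ffmpeg movie writer calls an external ffmpeg executable (by default it looks for ffmpeg, per the animation.ffmpeg_path setting). A pre-flight check before rendering animations can therefore look for the binary on PATH. The sketch below assumes the default path setting, and ffmpeg_available is a hypothetical helper name.

```python
import shutil


def ffmpeg_available():
    """Return True if an ffmpeg executable is found on PATH."""
    return shutil.which("ffmpeg") is not None


if not ffmpeg_available():
    print("ffmpeg not found; install it, e.g. with: conda install -c conda-forge ffmpeg")
```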

    opened by katywarr
  • latency() got an unexpected keyword argument 'num_outputs' and latency() got multiple values for argument 'num_steps'

    • snntorch version: 0.2.11
    • Python version: 3.7.10

    Description

    In the latest updates, I see that the num_outputs argument has been removed from spikegen. No clue why!

    Moreover, when I simply remove the num_outputs argument, I get another error about num_steps whenever I set it to any int value.

    BTW, the same happens if I try rate encoding instead of latency encoding.

    What I Did

    Here you will find a quick .ipynb file that shows the errors.
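Both messages in the title are standard Python TypeErrors about how call arguments bind to a function's signature. They are not snnTorch-specific logic. The second one appears when the same parameter is filled both positionally and by keyword. One plausible cause, not confirmed in the thread, is a call whose positional arguments were written for an older signature. The encode function below is a hypothetical stand-in used only to reproduce the two messages, and it is not spikegen's real signature.

```python
def encode(data, num_steps=1, threshold=0.01):
    """Hypothetical encoder signature, used only to reproduce the two error messages."""
    return [data] * num_steps


def error_message(call):
    """Run call() and return the TypeError message, or None if it succeeds."""
    try:
        call()
    except TypeError as err:
        return str(err)
    return None


# A keyword the signature does not accept:
print(error_message(lambda: encode(0.5, num_outputs=10)))
# encode() got an unexpected keyword argument 'num_outputs'

# num_steps filled positionally (the 10) *and* by keyword:
print(error_message(lambda: encode(0.5, 10, num_steps=10)))
# encode() got multiple values for argument 'num_steps'
```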

    opened by Dola47
  • Neurons can fire multiple time steps in a row.

    It is possible to have neurons firing continuously using the default reset mechanism or if using RLeaky and reset to zero. The latter is due to only resetting the input but not the recurrent connections. This is undesirable behavior as it allows the neurons to essentially not be spiking neurons given the right weight values.
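The behavior described above can be reproduced with a simplified plain-Python neuron. In this version the reset is applied in the same step as the spike (a simplification, not snnTorch's exact update). Whenever the input alone exceeds the threshold, the neuron fires at every step under both reset-by-subtraction and reset-to-zero. With subtraction, the leftover membrane potential also keeps growing. simulate and its reset_mechanism argument are hypothetical names for this sketch.

```python
def simulate(inputs, beta=0.9, threshold=1.0, reset_mechanism="subtract"):
    """Simplified LIF neuron with an immediate reset after each spike."""
    if reset_mechanism not in ("subtract", "zero"):
        raise ValueError("reset_mechanism must be 'subtract' or 'zero'")
    mem, spikes = 0.0, []
    for x in inputs:
        mem = beta * mem + x
        spk = 1 if mem > threshold else 0
        if spk:
            mem = mem - threshold if reset_mechanism == "subtract" else 0.0
        spikes.append(spk)
    return spikes


strong = [1.5] * 5  # the input alone exceeds the threshold at every step
print(simulate(strong, reset_mechanism="subtract"))  # [1, 1, 1, 1, 1]
print(simulate(strong, reset_mechanism="zero"))      # [1, 1, 1, 1, 1]
print(simulate([0.5] * 5, reset_mechanism="zero"))   # [0, 0, 1, 0, 0]
```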

    opened by EvilxFish
  • enhance(tutorial6): net definition and link

    • make the Net class forward method single-step
    • insert a link to Tutorial 5 where it is cited

    Motivation

    There are several points in this very useful tutorial where I propose this improvement. For clarity, I will call "v1" the current version of the Net definition and "v2" the version proposed in this PR.

    • Using "Run all" raises a shape-incompatibility error in the loss calculation, because in v1 .forward contains an internal loop, which adds an extra axis. With v2, all cells run without errors.
    • .forward in v1 is apparently inconsistent with .forward in v2: the former has an inner loop (multi-step), while the latter does not (single-step).
    • (Am I overthinking?) It is not clear whether the inner loop was intended to define a smaller dt, i.e., an internal frequency higher than the external (receptor) frequency; if so, it might be beneficial to make that explicit.

    It is clear that the Tutorial 5 example was being cited, but in this context it comes across as a bit inconsistent. If I have misunderstood, please tell me how I should best understand the tutorial.

    Thank you for the effort of creating this project :)

    opened by gianfa
  • Detach and Reset Spikes in RLeaky

    I noticed that in RLeaky neurons the spike acts as an internal state but is neither reset nor detached. I believe the spikes should be reset and detached like the membrane potential in reset_hidden and detach_hidden, respectively. This merge request adds support for resetting and detaching the spikes.
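Why resetting the spike matters can be shown with a toy recurrent neuron whose previous spike feeds back through a recurrent weight. If only the membrane potential is cleared between sequences, a spike left over from the previous sequence still drives the first step of the next one. The class below is hypothetical and is not snnTorch's RLeaky. The detach half of the PR concerns PyTorch's autograd graph and is not modeled here.

```python
class ToyRecurrentLeaky:
    """Toy recurrent LIF neuron (hypothetical): the previous spike feeds back via w_rec."""

    def __init__(self, beta=0.5, w_rec=0.8, threshold=1.0):
        self.beta, self.w_rec, self.threshold = beta, w_rec, threshold
        self.reset_hidden()

    def reset_hidden(self, reset_spk=True):
        self.mem = 0.0
        if reset_spk:  # the change proposed in the PR: clear the spike state too
            self.spk = 0

    def step(self, x):
        self.mem = self.beta * self.mem + x + self.w_rec * self.spk
        self.spk = 1 if self.mem > self.threshold else 0
        if self.spk:
            self.mem -= self.threshold  # reset by subtraction
        return self.spk

    def run(self, inputs):
        return [self.step(x) for x in inputs]


neuron = ToyRecurrentLeaky()
neuron.run([1.2])                     # first sequence ends with a spike
neuron.reset_hidden(reset_spk=False)  # membrane cleared, stale spike kept
print(neuron.run([0.5]))              # [1]: driven by the previous sequence's spike
neuron.reset_hidden()                 # membrane and spike both cleared
print(neuron.run([0.5]))              # [0]
```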

    opened by manuelbre
  • Issue with inputing custom weights for Rate based SNN

    • snntorch version: 0.4.4
    • Python version: 3.9.6
    • Operating System: Ubuntu 20.04.3 LTS

    Description

    Hi Jason, first of all, I appreciate your wonderful effort in developing this package and its detailed documentation. I have recently started using snntorch for rate-based SNN coding. Although I get good performance for purely software-based runs, I am facing issues when inputting custom weights extracted from a synaptic device. My accuracy gets stuck at around 10%, which is the same as the untrained accuracy.

    What I Did

    I used a custom function to input the weights from a text file, as shown in the screenshot. Please let me know how to solve this issue. NB: I am pretty new to programming, so please excuse me if my code is too cumbersome :) Here is the full file and the text file for data input: rate_SNN_dev_weights.zip

    Thanks, Kannan
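The asker's code is not visible here, so the following is only a general check. torch.nn.Linear stores its weight as (out_features, in_features), and a matrix exported in the transposed layout (or with the wrong shape) can silently produce near-chance accuracy. The sketch below parses a whitespace-separated text matrix and checks its shape against the layer. It transposes the matrix if the file is stored as (in_features, out_features). The parsed rows could then be copied into a layer inside torch.no_grad() with layer.weight.copy_(torch.tensor(rows)). parse_weights is a hypothetical helper, and a square matrix cannot be disambiguated this way.

```python
def parse_weights(text, out_features, in_features):
    """Parse whitespace-separated rows into a nested list shaped (out_features, in_features).

    Transposes the matrix if the file stores it as (in_features, out_features).
    """
    rows = [[float(v) for v in line.split()] for line in text.strip().splitlines() if line.strip()]
    n_cols = len(rows[0]) if rows else 0
    if any(len(r) != n_cols for r in rows):
        raise ValueError("rows have different lengths")
    shape = (len(rows), n_cols)
    if shape == (out_features, in_features):
        return rows
    if shape == (in_features, out_features):
        return [list(col) for col in zip(*rows)]  # transpose to (out, in)
    raise ValueError(f"expected ({out_features}, {in_features}) or its transpose, got {shape}")


print(parse_weights("1 2 3\n4 5 6", out_features=3, in_features=2))
# [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```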

    opened by kannanum
  • RSynaptic Neuron Model NOT in snnTorch

    Hello, I am trying to use the RSynaptic neuron model, but even though it is documented on the snnTorch website, the class RSynaptic(LIF) is not included in the snntorch module. This is the error I get when I try to use it:

    AttributeError: module 'snntorch' has no attribute 'RSynaptic'

    opened by msbouanane
  • Tonic example

    Added an example notebook that shows how to download data so that it is ready to feed to a network. Tonic does support batching for tensor representations; if you want that, please let me know and I'll add it.

    opened by biphasic
  • TBPTT mode gives error about K parameter on time varying data

    • snntorch version: 0.5.3
    • Python version: 3.8
    • Operating System: windows

    When I try to use TBPTT or RTRL from backprop on a time-varying signal, it gives me this error:

        if K_flag is False:
    UnboundLocalError: local variable 'K_flag' referenced before assignment
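The error above follows a generic Python pattern. A local variable that is assigned only inside conditional branches raises UnboundLocalError when none of those branches runs. The snippet below reproduces the pattern with hypothetical functions and is not snnTorch's backprop code. Binding the flag unconditionally before it is read removes the error.

```python
def truncation_buggy(K, num_steps):
    # K_flag is only bound when K is False, so a numeric K reaches the
    # read below with K_flag unassigned -> UnboundLocalError.
    if K is False:
        K_flag = True
    if K_flag is False:
        return K
    return num_steps


def truncation_fixed(K, num_steps):
    K_flag = K is False  # bound on every path
    if K_flag is False:
        return min(K, num_steps)  # keep the truncation window within the sequence
    return num_steps


print(truncation_fixed(5, 25))      # 5
print(truncation_fixed(False, 25))  # 25
```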
    
    opened by alisam1992
  • Add power profiling capabilities

    This seems to be a super popular feature request. Making accurate estimates seems near impossible, but we can probably generate an order of magnitude guess here.

    The user would construct a model, pass data in, and the power profiling function would return the number of synaptic operations in the forward pass (this could be averaged across batches).

    Each synaptic op would be scaled by the energy cost for all selected devices, e.g., various GPUs and neuromorphic hardware. The same number would be given for non-spiking networks too. This could be achieved by just removing the spiking modules.

    SpikingKeras has a similar function that does it really nicely. However, it overstates the improvement given with spikes because it does not account for overhead (i.e., moving data to/from memory, or between multiple chips).

    Including an argument that factors in overhead would be tricky, but useful. The model would be parsed for the number of neurons/synapses, and if either exceeds the bandwidth of a single chip, then we need to estimate how frequently data needs to be moved between chips and add that to the overall energy consumption.

    A lot of coarse estimates would be made, but I think it could be helpful.
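A rough sketch of the counting step proposed above, for a stack of dense layers only, is shown below. In the spiking case, each incoming spike triggers one accumulate per outgoing synapse, so a layer's synaptic operations are its input spike count (summed over all time steps) times its fan-out. The non-spiking comparison counts one multiply-accumulate per weight per forward pass. The per-op energy is a parameter the user would supply for each device. All function names are hypothetical, no device figures are assumed, and memory and chip-to-chip overhead are ignored, as the issue warns.

```python
def snn_synops(input_spike_counts, layer_sizes):
    """Synaptic operations of a spiking MLP.

    input_spike_counts[i]: spikes entering dense layer i, summed over all time steps.
    layer_sizes[i]: (in_features, out_features) of dense layer i.
    """
    return sum(n_spk * out_f for n_spk, (_, out_f) in zip(input_spike_counts, layer_sizes))


def ann_macs(layer_sizes):
    """Multiply-accumulates of the same MLP without spiking modules, for one forward pass."""
    return sum(in_f * out_f for in_f, out_f in layer_sizes)


def energy_estimate(num_ops, joules_per_op):
    """Order-of-magnitude energy: op count times a user-supplied per-op cost."""
    return num_ops * joules_per_op


layers = [(784, 100), (100, 10)]
print(snn_synops([1000, 50], layers))  # 1000*100 + 50*10 = 100500
print(ann_macs(layers))                # 784*100 + 100*10 = 79400
```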

    enhancement 
    opened by jeshraghian
  • snntorch multi GPU training issue

    • snntorch version: 0.5.3
    • Python version: 3.9
    • Operating System: linux
    • nvidia-smi
    Every 0.5s: nvidia-smi                                                                                                                                    neuro: Fri Dec  2 11:16:53 2022
    
    Fri Dec  2 11:16:53 2022
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  NVIDIA RTX A6000    On   | 00000000:1B:00.0 Off |                  Off |
    | 30%   30C    P8    29W / 300W |      1MiB / 48682MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   1  NVIDIA RTX A6000    On   | 00000000:1C:00.0 Off |                  Off |
    | 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   2  NVIDIA RTX A6000    On   | 00000000:1D:00.0 Off |                  Off |
    | 30%   32C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   3  NVIDIA RTX A6000    On   | 00000000:1E:00.0 Off |                  Off |
    | 30%   31C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   4  NVIDIA RTX A6000    On   | 00000000:3D:00.0 Off |                  Off |
    | 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   5  NVIDIA RTX A6000    On   | 00000000:3F:00.0 Off |                  Off |
    | 30%   29C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   6  NVIDIA RTX A6000    On   | 00000000:40:00.0 Off |                  Off |
    | 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   7  NVIDIA RTX A6000    On   | 00000000:41:00.0 Off |                  Off |
    | 30%   30C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+
    
    

    Description

    I'm trying to train NMNIST with snntorch using multiple GPUs. Since snntorch is based on the torch package, I thought DataParallel from torch.nn should work. Here's the whole code.

    import torch
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import torch.nn.init
    import os
    import torch.nn as nn
    import time
    import matplotlib.pyplot as plt
    import tonic.transforms as transforms
    import tonic
    import numpy as np
    import snntorch as snn
    from snntorch import surrogate
    from snntorch import functional as SF
    from snntorch import spikeplot as splt
    from snntorch import utils
    import torch.nn as nn
    import os
    from torch.utils.data import DataLoader, random_split
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sensor_size = tonic.datasets.NMNIST.sensor_size
    
    # Denoise removes isolated, one-off events
    # time_window
    frame_transform = transforms.ToFrame(sensor_size=sensor_size, time_window=1)
    
    
    frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),
                                          transforms.ToFrame(sensor_size=sensor_size,
                                                             time_window=50000)
                                         ])
    
    trainset = tonic.datasets.NMNIST(save_to='/home/hubo1024/PycharmProjects/snntorch/data/NMNIST', transform=frame_transform, train=True)
    testset = tonic.datasets.NMNIST(save_to='./home/hubo1024/PycharmProjects/snntorch/data/NMNIST', transform=frame_transform, train=False)
    
    # seed fix
    torch.manual_seed(777)
    
    # seed fix if gpu is available
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)
    
    #batch_size = 100
    
    batch_size = 32
    dataset_size = len(trainset)
    train_size = int(dataset_size * 0.9)
    validation_size = int(dataset_size * 0.1)
    
    
    trainset, valset = random_split(trainset, [train_size, validation_size])
    print(len(valset))
    print(len(trainset))
    trainloader = DataLoader(trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
    valloader = DataLoader(valset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())
    
    
    spike_grad = surrogate.fast_sigmoid(slope=75)
    beta = 0.5
    
    class CNN(torch.nn.Module):
    
        def __init__(self):
            super(CNN, self).__init__()
            self.keep_prob = 0.5
            self.layer1 = torch.nn.Sequential(
                nn.Conv2d(2, 12, 5),
                nn.MaxPool2d(2),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
            )
    
            self.layer2 = torch.nn.Sequential(
                nn.Conv2d(12, 32, 5),
                nn.MaxPool2d(2),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
            )
    
            self.layer4 = torch.nn.Sequential(
                nn.Flatten(),
                nn.Linear(32 * 5 * 5, 10),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
            )
    
        def forward(self, data):
            spk_rec = []
            layer1_rec = []
            layer2_rec = []
            utils.reset(self.layer1)  # resets hidden states for all LIF neurons in net
            utils.reset(self.layer2)
            utils.reset(self.layer4)
    
            for step in range(data.size(1)):  # data.size(1) = number of time steps (data is batch-first)
                input_torch = data[:, step, :, :, :]
                input_torch = input_torch.cuda()
                #print(input_torch)
                out = self.layer1(input_torch)
                #out1 = out
    
                out = self.layer2(out)
                #out2 = out
                out, mem = self.layer4(out)
                #out = self.layer4(out)
    
                spk_rec.append(out)
    
                #layer1_rec.append(out1)
                #layer2_rec.append(out2)
    
            return torch.stack(spk_rec)#, torch.stack(layer1_rec), torch.stack(layer2_rec)
    
    
    model = CNN().to(device)
    device_ids = [0, 1] #your GPU index
    model = torch.nn.DataParallel(model, device_ids=device_ids)
    #model = nn.DataParallel(model).to(device)
    optimizer = torch.optim.NAdam(model.parameters(), lr=0.005,betas=(0.9, 0.999))
    loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
    #model = nn.DataParallel(model)
    
    total_batch = len(trainloader)
    print('Total number of batches: {}'.format(total_batch))
    loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
    num_epochs = 15
    loss_hist = []
    acc_hist = []
    v_acc_hist = []
    t_spk_rec_sum = []
    start = time.time()
    val_cnt = 0
    v_acc_sum= 0
    avg_loss = 0
    index = 0
    #################################################
    
    
    for epoch in range(num_epochs):
        torch.save(model.state_dict(), '/home/hubo1024/PycharmProjects/snntorch/model_pt/Radam_15epoch-50000.pt')
        for i, (data, targets) in enumerate(iter(trainloader)):
            data = data.cuda()
            targets = targets.cuda()
            model.train()
    
            spk_rec = model(data)
    
            #print(spk_rec.shape)
            loss_val = loss_fn(spk_rec, targets)
            avg_loss += loss_val.item()
            optimizer.zero_grad()
    
            loss_val.backward()
    
            optimizer.step()
    
            # Store loss history for future plotting
            loss_hist.append(loss_val.item())
            val_cnt = val_cnt+1
            #del loss_val
    
    
            if val_cnt == len(trainloader)/2-1:
                val_cnt=0
    
                for ii, (v_data, v_targets) in enumerate(iter(valloader)):
                    v_data = v_data.to(device)
                    v_targets = v_targets.to(device)
    
                    v_spk_rec = model(v_data)
                    #
                    # print(t_spk_rec.shape)
                    v_acc = SF.accuracy_rate(v_spk_rec, v_targets)
                    del v_spk_rec
                    if ii == 0:
                        v_acc_sum = v_acc
                        cnt = 1
    
                    else:
                        v_acc_sum += v_acc
                        cnt += 1
                    #del v_acc
    
    
                plt.plot(acc_hist)
                plt.plot(v_acc_hist)
                plt.legend(['train accuracy', 'validation accuracy'])
                plt.title("Train, Validation Accuracy-Radam 15epoch-50000")
                plt.xlabel("Iteration")
                plt.ylabel("Accuracy")
                # plt.show()
                plt.savefig('Radam_15epoch-50000.png')
                plt.clf()
                v_acc_sum = v_acc_sum/cnt
    
    
                # avg_loss = avg_loss / (len(trainloader) / 2)
                # print('average loss while half epoch', avg_loss)
                # if avg_loss <= 0.5:
                #     index = 1
                #     break
                # else:
                #     avg_loss = 0
                #     index = 0
    
            print('Radam-15epoch-50000')
            print("time :", time.time() - start,"sec")
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")
    
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            v_acc_hist.append(v_acc_sum)
            print(f"Train Accuracy: {acc * 100:.2f}%")
            print(f"Validation Accuracy: {v_acc_sum * 100:.2f}%\n")
    
        #     if index == 1:
        #         break
        # if index == 1:
        #     break
    # No training is done here, so use torch.no_grad()
    '''
    with torch.no_grad():
        X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
        Y_test = mnist_test.test_labels.to(device)
    
        prediction = model(X_test)
        correct_prediction = torch.argmax(prediction, 1) == Y_test
        accuracy = correct_prediction.float().mean()
        print('Accuracy:', accuracy.item())
    '''
    
    

    And here's the error:

    (snn_torch) ~/PycharmProjects/snntorch$ CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python gpu_6_run.py
    6000
    54000
    Total number of batches: 13500
    Traceback (most recent call last):
      File "/home/hubo1024/PycharmProjects/snntorch/gpu_6_run.py", line 146, in <module>
        spk_rec = model(data)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
        output.reraise()
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
        raise exception
    RuntimeError: Caught RuntimeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/PycharmProjects/snntorch/gpu_6_run.py", line 102, in forward
        out = self.layerconv1(input_torch)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
        input = module(input)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 162, in forward
        self.mem = self.state_fn(input_)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 201, in _build_state_function_hidden
        self._base_state_function_hidden(input_) - self.reset * self.threshold
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 195, in _base_state_function_hidden
        base_fn = self.beta.clamp(0, 1) * self.mem + input_
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_tensor.py", line 1121, in __torch_function__
        ret = func(*args, **kwargs)
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    
    

    I reran this code after removing the snn.Leaky layers from the CNN, and it ran fine (of course the cost doesn't converge and the accuracy was 0%, but it still runs). So I assume this error is caused by the snn.Leaky layer. I think changing

    opened by rkdgmlqja 5
  • Clean .flake8

    Clean .flake8

    The .flake8 config currently excludes many hints and errors. The codebase should be cleaned up so that it passes with the standard Flake8 config.

    I would be glad to help out. Reporting as a reminder.

    opened by ahenkes1 1
  • Add class imbalance weighting to loss functions

    Add class imbalance weighting to loss functions

    Apply on/off target weighting to snntorch.functional losses in the same way that PyTorch enables weighting.

    Cross-entropy-based losses should be straightforward; mean-square-error losses are less trivial.

    enhancement 
    opened by jeshraghian 0
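The following is a minimal sketch of one common workaround for the device-mismatch error in the first issue above. The traceback fails inside self.beta.clamp(0, 1) * self.mem + input_. That suggests the hidden state stored as a module attribute did not follow each nn.DataParallel replica onto its GPU, though this is an assumption and not a confirmed diagnosis. The pattern below keeps the membrane potential local to forward() and allocates it with torch.zeros_like from the input, so each replica creates its state on its own device.

LocalStateLeaky is a hypothetical pure-PyTorch layer, not an snnTorch class. It uses a hard threshold with no surrogate gradient, so it illustrates state handling only. Inputs are batch-first because nn.DataParallel scatters along dim 0 by default.

```python
import torch
import torch.nn as nn


class LocalStateLeaky(nn.Module):
    """Hypothetical LIF layer whose membrane state exists only inside forward()."""

    def __init__(self, beta=0.9, threshold=1.0):
        super().__init__()
        # A buffer moves with .to(device) and is copied to every DataParallel replica.
        self.register_buffer("beta", torch.tensor(beta))
        self.threshold = threshold

    def forward(self, cur_seq):
        # cur_seq: [batch, num_steps, features]; batch-first because
        # nn.DataParallel splits inputs along dim 0 by default.
        mem = torch.zeros_like(cur_seq[:, 0])  # allocated on the input's device
        spikes = []
        for t in range(cur_seq.size(1)):
            mem = self.beta.clamp(0, 1) * mem + cur_seq[:, t]
            # Hard threshold: no surrogate gradient in this sketch.
            spk = (mem > self.threshold).to(cur_seq.dtype)
            mem = mem - spk * self.threshold  # reset by subtraction
            spikes.append(spk)
        return torch.stack(spikes, dim=1)


layer = LocalStateLeaky()
model = nn.DataParallel(layer) if torch.cuda.device_count() > 1 else layer
```

With snnTorch itself, the analogous approach would be to construct snn.Leaky with init_hidden=False and pass the membrane potential explicitly at each step (spk, mem = lif(cur, mem)), creating mem on the input's device. This has not been checked against the version shown in the traceback.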
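The class-imbalance issue above asks for PyTorch-style weighting in snntorch.functional. The following is a rough sketch of what that could look like for a cross-entropy rate loss. It is not an existing snnTorch API.

The hypothetical weighted_ce_rate_loss applies torch.nn.functional.cross_entropy with a per-class weight tensor at every time step and averages over steps. The hypothetical inverse_frequency_weights is one assumed way to build those weights from a set of labels. The harder mean-square-error case mentioned in the issue is not covered.

```python
import torch
import torch.nn.functional as F


def weighted_ce_rate_loss(spk_rec, targets, class_weights):
    """Hypothetical class-weighted cross-entropy applied at every time step.

    spk_rec: [num_steps, batch, num_classes] output spikes (or membrane values)
    targets: [batch] integer class labels
    class_weights: [num_classes] weight per class, as in torch.nn.CrossEntropyLoss
    """
    num_steps = spk_rec.size(0)
    loss = spk_rec.new_zeros(())
    for t in range(num_steps):
        loss = loss + F.cross_entropy(spk_rec[t], targets, weight=class_weights)
    return loss / num_steps  # average over time steps


def inverse_frequency_weights(targets, num_classes):
    """Weight each class by total / (num_classes * count); absent classes get 0."""
    counts = torch.bincount(targets, minlength=num_classes).float()
    balanced = counts.sum() / (num_classes * counts)
    return torch.where(counts > 0, balanced, torch.zeros_like(counts))
```

With all weights equal to 1, this reduces to an unweighted per-step cross-entropy averaged over time. In practice the weights would more likely be computed once over the full training labels than per batch.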
Releases (v0.5.2)
  • v0.5.2(Aug 4, 2022)

    What's Changed

    • leaky and rleaky state function subtract function fix by @pengzhouzp in https://github.com/jeshraghian/snntorch/pull/95
    • Detach and Reset Spikes in RLeaky by @manuelbre in https://github.com/jeshraghian/snntorch/pull/108
    • Integrate ATan Surrogate function. by @ridgerchu in https://github.com/jeshraghian/snntorch/pull/111
    • bptt bug may trigger device inconsistency. by @ridgerchu in https://github.com/jeshraghian/snntorch/pull/115
    • Add a new feature 'probe' by @ridgerchu in https://github.com/jeshraghian/snntorch/pull/117

    New Contributors

    • @pengzhouzp made their first contribution in https://github.com/jeshraghian/snntorch/pull/95
    • @manuelbre made their first contribution in https://github.com/jeshraghian/snntorch/pull/108
    • @ridgerchu made their first contribution in https://github.com/jeshraghian/snntorch/pull/111
    • @MegaYEye made their first contribution in https://github.com/jeshraghian/snntorch/pull/118

    Full Changelog: https://github.com/jeshraghian/snntorch/compare/v0.5.1...v0.5.2

  • v0.5.0(Feb 10, 2022)

    What's new?

    • Refactored the structure of neuron models to make it easier to integrate custom neurons
    • Added recurrent Leaky neuron: RLeaky
    • Added recurrent Synaptic neuron: RSynaptic
    • Added spiking LSTM neurons: SLSTM
    • Added spiking convolutional 2D LSTMs: SConv2dLSTM
    • Learnable thresholds for all neurons
    • Learnable explicit recurrence
    • Reset mechanism now includes 'none' as an option
    • Updated unit tests

    snntorch.surrogate

    • Triangular surrogate
    • Straight through estimator

    snntorch.functional

    • mse_temporal_loss function added: applies mean square error to the first F spikes. Options include a tolerance, as well as passing labels to be converted into spike-time targets.

    • ce_temporal_loss added: applies cross-entropy loss to an inversion of the first spike time. Inversion options include -1 * x and 1/x, so that maximizing the logit of the correct class corresponds to minimizing the correct neuron's firing time.

    • accuracy_temporal added: measures accuracy based on the occurrence of the first spike

    Full Changelog: https://github.com/jeshraghian/snntorch/compare/v0.4.11...v0.5.0

  • v0.2.11(May 17, 2021)

    Several bugs from previous versions related to tensor sizes in spike encoding have now been fixed.

    What's new?

    snntorch.spikegen

    • Data and target conversion have been separated
    • Conversion sizes have been fixed
    • A time dimension is only created if the tensor is time-varying (i.e., latency coding always has a time dimension; rate coding might not)
    • Latency and rate target conversion
    • Options included for interpolation, on/off spike values, time to first spike, and on/off rates

    snntorch.surrogate

    • Surrogate gradient parameters have been moved from global variables to local variables within closures
    • Spike operator (1/u)
    • Leaky local spike operator (equivalent to a shifted leaky ReLU)
    • Local stochastic spike operator
  • v0.2.1(Feb 27, 2021)

    Some bugs from previous versions have now been fixed.

    What's new?

    snntorch

    • SRM0 neuron model fix
    • Reset now applies the threshold rather than '1'
    • Reset by subtraction and reset to zero methods applied to both Stein and SRM0 neurons

    snntorch.spikegen

    • Delta modulation

    snntorch.surrogate

    • Optimized grad calculation

    dev notes

    • Travis CI is no longer free, so travis.yml was replaced with GitHub Actions integration plus tox
  • v0.1.2(Feb 11, 2021)

    The first functional iteration of snnTorch!

    What's new?

    snntorch

    The workhorse of the package. All neuron models are integrated here, and a default Heaviside gradient is used to override the non-differentiability of the spike with conventional autograd methods in PyTorch.

    • Stein's neuron model
    • SRM0 neuron model
    • firing inhibition, thanks to @xxwang1
    • hidden states can optionally be initialized as instance variables if the user wants to just use a built-in backprop method

    snntorch.backprop

    • Backprop through time (BPTT)
    • Truncated backprop through time (TBPTT)
    • Real-time recurrent learning (RTRL)

    snntorch.spikegen

    • Poisson spike train generator
    • Rate coding
    • Latency coding

    snntorch.surrogate

    • FastSigmoid
    • Sigmoid
    • Spike Rate Escape

    snntorch.spikeplot

    • Raster plots
    • Feature map animator
    • Spike count animator

    snntorch.utils

    • Data split
    • Data reduction

    Plans for alpha-2

    • delta & delta-sigma spike generators for snntorch.spikegen
    • Simplified Stein's model (reduce hidden states from 2 to 1)
    • More surrogate and backprop methods
    • add more tests
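The v0.5.0 temporal functions listed above all start from each output neuron's first spike time. The following sketch shows the forward computations only, under stated assumptions:

* first_spike_time assigns num_steps to neurons that never fire.
* ce_temporal_loss_sketch turns first spike times into logits with the -1 * x or 1/x inversion. The +1 shift in the 1/x branch is an addition of ours to avoid dividing by zero at step 0.
* accuracy_temporal_sketch predicts the class whose neuron fires first.

All three names are hypothetical. Taking a minimum over spike indices is not differentiable, so this does not reproduce how snnTorch propagates gradients through these losses.

```python
import torch
import torch.nn.functional as F


def first_spike_time(spk_rec):
    """Step index of each neuron's first spike; silent neurons get num_steps.

    spk_rec: [num_steps, batch, num_classes] binary spikes -> [batch, num_classes]
    """
    num_steps = spk_rec.size(0)
    steps = torch.arange(num_steps, device=spk_rec.device).view(-1, 1, 1)
    never = torch.full_like(steps, num_steps)
    times = torch.where(spk_rec > 0, steps, never)
    return times.min(dim=0).values


def ce_temporal_loss_sketch(spk_rec, targets, inversion="negate"):
    t_first = first_spike_time(spk_rec).float()
    if inversion == "negate":
        logits = -t_first  # -1 * x: earlier spike -> larger logit
    else:
        logits = 1.0 / (t_first + 1.0)  # 1/x, shifted by 1 so t = 0 is finite
    return F.cross_entropy(logits, targets)


def accuracy_temporal_sketch(spk_rec, targets):
    # Predicted class = the output neuron that fires first.
    pred = first_spike_time(spk_rec).argmin(dim=1)
    return (pred == targets).float().mean().item()
```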
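Several releases above add spike generators to snntorch.spikegen: rate coding, latency coding and delta modulation. The sketch below illustrates the underlying ideas in plain PyTorch rather than calling snnTorch, so the function names and parameters are hypothetical:

* rate_code fires a Bernoulli spike at each step with probability equal to the input intensity.
* latency_code_linear emits one spike per input, with brighter inputs firing earlier. It uses a linear time mapping, chosen here for simplicity.
* delta_code emits an "on" spike wherever a signal increases by at least a threshold between consecutive steps. Off spikes are omitted.

```python
import torch


def rate_code(data, num_steps):
    """Bernoulli rate coding; output shape [num_steps, *data.shape]."""
    probs = data.clamp(0, 1)
    noise = torch.rand((num_steps,) + tuple(data.shape), device=data.device)
    return (noise < probs).to(data.dtype)


def latency_code_linear(data, num_steps):
    """One spike per input; intensity 1 fires at step 0, intensity 0 at the last step."""
    spike_time = torch.round((1 - data.clamp(0, 1)) * (num_steps - 1)).long()
    steps = torch.arange(num_steps, device=data.device).view(-1, *([1] * data.dim()))
    return (steps == spike_time.unsqueeze(0)).to(data.dtype)


def delta_code(seq, threshold=0.1):
    """'On' spikes where seq[t] - seq[t-1] >= threshold; seq is [num_steps, ...]."""
    diff = seq[1:] - seq[:-1]
    spikes = (diff >= threshold).to(seq.dtype)
    # No previous value exists at t = 0, so no spike is emitted there.
    return torch.cat([torch.zeros_like(seq[:1]), spikes])
```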
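The surrogate-gradient entries above (FastSigmoid in v0.1.2, ATan in v0.5.2) rely on the same trick. The forward pass uses a Heaviside step, and the backward pass substitutes a smooth derivative. The following is a minimal sketch with torch.autograd.Function, assuming the input is the membrane potential minus the threshold.

The class names, the slope and alpha values, and the exact derivative forms are assumptions. Check the snntorch.surrogate source before relying on them. The fast sigmoid uses 1 / (slope * |x| + 1)^2, and ATan uses the derivative of (1/pi) * arctan(pi/2 * alpha * x) + 1/2.

```python
import math

import torch


class FastSigmoidSpike(torch.autograd.Function):
    """Heaviside forward pass; fast-sigmoid derivative in the backward pass."""

    @staticmethod
    def forward(ctx, mem_shift, slope):
        # mem_shift is assumed to be (membrane potential - threshold).
        ctx.save_for_backward(mem_shift)
        ctx.slope = slope
        return (mem_shift > 0).to(mem_shift.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mem_shift,) = ctx.saved_tensors
        surrogate = 1.0 / (ctx.slope * mem_shift.abs() + 1.0) ** 2
        return grad_output * surrogate, None  # no gradient for slope


class ATanSpike(torch.autograd.Function):
    """Heaviside forward pass; arctan-based derivative in the backward pass."""

    @staticmethod
    def forward(ctx, mem_shift, alpha):
        ctx.save_for_backward(mem_shift)
        ctx.alpha = alpha
        return (mem_shift > 0).to(mem_shift.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mem_shift,) = ctx.saved_tensors
        # d/dx [1/pi * arctan(pi/2 * alpha * x)] = (alpha/2) / (1 + (pi/2 * alpha * x)^2)
        surrogate = (ctx.alpha / 2) / (1 + (math.pi / 2 * ctx.alpha * mem_shift) ** 2)
        return grad_output * surrogate, None  # no gradient for alpha
```

Both classes are called with explicit parameters, e.g. FastSigmoidSpike.apply(mem - threshold, 25.0). The backward methods return one gradient per forward input.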
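The reset mechanisms mentioned above can be compared in a single leaky integrate-and-fire step. v0.2.1 added reset by subtraction and reset to zero, and made the reset use the threshold rather than 1; v0.5.0 added 'none'.

This is a simplified sketch with a hypothetical lif_step function. It applies the reset in the same step as the spike. snnTorch's state function, visible in the traceback above as ... - self.reset * self.threshold, appears to derive the reset term from the previous membrane potential, so spike timing can differ by one step.

```python
import torch


def lif_step(mem, cur, beta=0.9, threshold=1.0, reset="subtract"):
    """One hypothetical LIF update; returns (spk, new_mem)."""
    mem = beta * mem + cur                  # leaky integration
    spk = (mem > threshold).to(mem.dtype)   # fire when above threshold
    if reset == "subtract":
        mem = mem - spk * threshold         # keep the overshoot above threshold
    elif reset == "zero":
        mem = mem * (1 - spk)               # clear the membrane after a spike
    elif reset != "none":
        raise ValueError(f"unknown reset mechanism: {reset}")
    return spk, mem
```

Reset by subtraction preserves the overshoot, which carries information about input strength into the next step. Reset to zero discards it.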
Owner
Jason Eshraghian
neuromorphic engineer