TimeSHAP explains Recurrent Neural Network predictions.

Overview

TimeSHAP

TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event- (timestamp-), feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm based on Shapley values that finds a subset of consecutive, recent events that contribute the most to the decision.

This repository contains the code implementation of the TimeSHAP algorithm presented in the paper TimeSHAP: Explaining Recurrent Models through Sequence Perturbations, published at KDD 2021.

Links to the paper here, and to the video presentation here.

Install TimeSHAP

Clone the repository into a local directory using:

git clone https://github.com/feedzai/timeshap.git

Install the package using pip:

pip install timeshap

To test your installation, start a Python session in your terminal using

python

And import TimeSHAP

import timeshap

TimeSHAP in 30 seconds

Inputs

  • Model being explained;
  • Instance(s) to explain;
  • Background instance.

Outputs

  • Local pruning output (explaining a single instance);
  • Local event explanations (explaining a single instance);
  • Local feature explanations (explaining a single instance);
  • Global pruning statistics (explaining multiple instances);
  • Global event explanations (explaining multiple instances);
  • Global feature explanations (explaining multiple instances).

Model Interface

In order for TimeSHAP to explain a model, an entry point must be provided. This Callable entry point must receive a 3-D numpy array of shape (#sequences; #sequence length; #features) and return a 2-D numpy array of shape (#sequences; 1) with the corresponding score of each sequence. Optionally, to make TimeSHAP more efficient, the entry point can also return the hidden state of the model together with the score (if applicable).

TimeSHAP is able to explain any black-box model as long as it complies with the previously described interface, including both PyTorch and TensorFlow models, both exemplified in our tutorials (PyTorch, TensorFlow).
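As a concrete illustration of this interface, the sketch below wraps a toy NumPy recurrent scorer (a hypothetical stand-in, not a TimeSHAP component) in two entry points: f returns only the scores, while f_hs also accepts and returns a hidden state. The weights, the sizes, and the (#sequences; hidden size) layout of the hidden state are assumptions made for this example; the layout a real model needs depends on the framework and wrapper in use.

```python
import numpy as np

# Toy parameters (assumptions for illustration only).
rng = np.random.default_rng(0)
N_FEATURES, HIDDEN = 3, 4
W_in = rng.normal(size=(N_FEATURES, HIDDEN))
W_h = rng.normal(size=(HIDDEN, HIDDEN)) * 0.1
w_out = rng.normal(size=(HIDDEN, 1))


def f_hs(x, hs=None):
    """Entry point that also receives and returns a hidden state.

    x  : array of shape (n_sequences, sequence_length, n_features)
    hs : optional array of shape (n_sequences, HIDDEN)
    Returns (scores of shape (n_sequences, 1), final hidden state).
    """
    h = np.zeros((x.shape[0], HIDDEN)) if hs is None else hs
    for t in range(x.shape[1]):
        # Simple recurrent update over the events of each sequence.
        h = np.tanh(x[:, t, :] @ W_in + h @ W_h)
    scores = 1.0 / (1.0 + np.exp(-(h @ w_out)))
    return scores, h


def f(x):
    """Entry point that returns only the (n_sequences, 1) scores."""
    return f_hs(x)[0]
```

With this toy model, scoring the first events, keeping the returned hidden state, and then scoring only the remaining events gives the same result as scoring the whole sequence, which suggests why passing hidden states around can avoid recomputation.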

Example provided in our tutorials:

  • TensorFlow
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)
  • PyTorch - (Example where model receives and returns hidden states)
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)

Model Wrappers

In order to facilitate the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, make it easier to explain a wider range of models by handling:

  • Batching logic: useful when using very large inputs or NSamples that cannot fit in GPU memory, so that batching mechanisms are required;
  • Input format/type: useful when your model does not work with numpy arrays. This is the case of our provided PyTorch example.
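To make these two responsibilities concrete, here is a minimal sketch of a wrapper that scores the input in fixed-size batches and optionally converts each batch to the format the model expects. The class name BatchedModelWrapper and its parameters are hypothetical and are not part of the TimeSHAP API; the wrappers shipped with TimeSHAP may work differently.

```python
import numpy as np


class BatchedModelWrapper:
    """Hypothetical wrapper: scores a large input in fixed-size batches."""

    def __init__(self, predict_fn, batch_size=1024, to_model_input=None):
        self.predict_fn = predict_fn
        self.batch_size = batch_size
        # Optional conversion from numpy to whatever the model consumes.
        self.to_model_input = to_model_input or (lambda arr: arr)

    def predict(self, x: np.ndarray) -> np.ndarray:
        outputs = []
        for start in range(0, x.shape[0], self.batch_size):
            batch = self.to_model_input(x[start:start + self.batch_size])
            scores = np.asarray(self.predict_fn(batch), dtype=float)
            outputs.append(scores.reshape(-1, 1))  # enforce (#sequences; 1)
        if not outputs:
            return np.empty((0, 1))
        return np.concatenate(outputs, axis=0)
```

The resulting wrapper.predict can then be passed as the entry point, since it receives a 3-D numpy array and returns a 2-D array with one score per sequence.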

TimeSHAP Explanation Methods

TimeSHAP offers several methods to use depending on the desired explanations. Local methods provide a detailed view of a model decision corresponding to a specific sequence being explained. Global methods aggregate local explanations of a given dataset to present a global view of the model.

Local Explanations

Pruning

local_pruning() performs the pruning algorithm on a given sequence with a given user-defined tolerance and returns the pruning index along with the information for plotting.

plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
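The idea behind pruning can be sketched as a two-player game between the recent events and the older events of a sequence: the split point moves towards the past until the Shapley value of the older group falls within the tolerance. The code below is a simplified illustration under that assumption, not the implementation behind local_pruning(); the function name prune_sketch, the toy model, and the convention of returning a positive index are all hypothetical.

```python
import numpy as np


def prune_sketch(f, x, baseline, tol):
    """Return the index of the oldest event kept after pruning.

    f        : entry point, (n_seq, seq_len, n_feat) -> (n_seq, 1)
    x        : one sequence, shape (seq_len, n_feat), oldest event first
    baseline : background event, shape (n_feat,)
    tol      : user-defined tolerance on the importance of older events
    """
    seq_len = x.shape[0]
    background = np.tile(baseline, (seq_len, 1))
    f_all = f(x[None])[0, 0]            # every event present
    f_none = f(background[None])[0, 0]  # every event replaced by the baseline

    # Move the split point from the most recent event towards the past.
    for split in range(seq_len - 1, -1, -1):
        recent_only = background.copy()
        recent_only[split:] = x[split:]   # older events hidden
        older_only = background.copy()
        older_only[:split] = x[:split]    # recent events hidden
        f_recent = f(recent_only[None])[0, 0]
        f_older = f(older_only[None])[0, 0]
        # Exact Shapley value of the "older events" player (2-player game).
        phi_older = 0.5 * ((f_older - f_none) + (f_all - f_recent))
        if abs(phi_older) <= tol:
            return split
    return 0


# Toy model: only feature 0 matters and older events are discounted.
weights = 0.5 ** np.arange(5, -1, -1)   # oldest -> most recent
toy_f = lambda batch: (batch[:, :, 0] * weights).sum(axis=1, keepdims=True)
sequence = np.ones((6, 1))
pruning_index = prune_sketch(toy_f, sequence, np.zeros(1), tol=0.1)
```

In this toy setting a larger tolerance prunes more aggressively, so fewer recent events are kept for the later explanation steps.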

Event level explanations

local_event() calculates event-level explanations of a given sequence with the user-given parameters and returns the respective event-level explanations.

plot_event_heatmap() plots the event-level explanations calculated by local_event().
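To show what an event-level attribution means, the sketch below computes exact Shapley values for the events of a very short sequence by enumerating every coalition of events, replacing the absent events with a baseline event. This brute-force enumeration is a stand-in chosen for clarity: TimeSHAP builds on KernelSHAP, which approximates these values by sampling, and the helper name exact_event_shapley and the toy model are hypothetical.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_event_shapley(f, x, baseline):
    """Exact Shapley value of each event; feasible only for short sequences.

    f        : entry point, (n_seq, seq_len, n_feat) -> (n_seq, 1)
    x        : one sequence, shape (seq_len, n_feat)
    baseline : background event, shape (n_feat,)
    """
    n = x.shape[0]

    def value(coalition):
        # Events outside the coalition are replaced by the baseline event.
        masked = np.tile(baseline, (n, 1))
        idx = np.array(coalition, dtype=int)
        masked[idx] = x[idx]
        return f(masked[None])[0, 0]

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


# Toy linear model over feature 0 with one weight per event.
event_weights = np.array([0.1, 0.3, 0.6])
toy_f = lambda batch: (batch[:, :, 0] * event_weights).sum(axis=1, keepdims=True)
x = np.array([[1.0], [2.0], [3.0]])
phi = exact_event_shapley(toy_f, x, baseline=np.zeros(1))
```

The attributions sum to the difference between the score of the sequence and the score of the all-baseline sequence, which is the efficiency property of Shapley values.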

Feature level explanations

local_feat() calculates feature-level explanations of a given sequence with the user-given parameters and returns the respective feature-level explanations.

plot_feat_barplot() plots the feature-level explanations calculated by local_feat().
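For feature-level attributions the players are the feature columns instead of the events. The sketch below estimates them with permutation sampling, a simple Monte Carlo alternative to KernelSHAP that is used here only because it is short; it is not how local_feat() computes its values. The function name and the toy model are hypothetical, and the nsamples and rs arguments merely play the roles of a sample budget and a random seed.

```python
import numpy as np


def sampled_feature_shapley(f, x, baseline, nsamples=200, rs=42):
    """Monte Carlo (permutation sampling) estimate of feature attributions.

    f        : entry point, (n_seq, seq_len, n_feat) -> (n_seq, 1)
    x        : one sequence, shape (seq_len, n_feat)
    baseline : background event, shape (n_feat,)
    """
    rng = np.random.default_rng(rs)
    n_feats = x.shape[1]
    background = np.tile(baseline, (x.shape[0], 1))
    phi = np.zeros(n_feats)
    for _ in range(nsamples):
        current = background.copy()
        previous_score = f(current[None])[0, 0]
        # Reveal whole feature columns one at a time, in random order.
        for j in rng.permutation(n_feats):
            current[:, j] = x[:, j]
            score = f(current[None])[0, 0]
            phi[j] += score - previous_score
            previous_score = score
    return phi / nsamples


# Toy model: the score is a weighted sum of each feature over all events.
coef = np.array([1.0, -2.0, 0.5])
toy_f = lambda batch: (batch.sum(axis=1) @ coef)[:, None]
phi = sampled_feature_shapley(toy_f, np.ones((4, 3)), baseline=np.zeros(3))
```

For a model with feature interactions the estimate depends on the seed and on the number of sampled permutations, so more samples trade computation time for stability.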

Cell level explanations

local_cell_level() calculates cell-level explanations of a given sequence with the respective event- and feature-level explanations and user-given parameters, returning the respective cell-level explanations.

plot_cell_level() plots the cell-level explanations calculated by local_cell_level().

Local Report

local_report() calculates TimeSHAP local explanations for a given sequence and plots them.

Global Explanations

Global pruning statistics

prune_all() performs the pruning algorithm on multiple given sequences.

pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances using the pruning data calculated by prune_all(), returning a pandas.DataFrame with the statistics.

Global event level explanations

event_explain_all() calculates TimeSHAP event-level explanations for multiple instances given user-defined parameters.

plot_global_event() plots the global event-level explanations calculated by event_explain_all().
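One plausible way to turn local event explanations into a global view is to align sequences on their most recent event and average the attributions at each relative position. The sketch below illustrates that assumption only; the helper aggregate_event_attributions is hypothetical, and the aggregation actually performed by event_explain_all() and plot_global_event() may differ.

```python
import numpy as np


def aggregate_event_attributions(local_attributions):
    """Average event attributions across sequences of different lengths.

    local_attributions : list of 1-D arrays, each ordered oldest -> most recent
    Returns (relative event index, mean attribution, number of sequences).
    """
    max_len = max(len(attr) for attr in local_attributions)
    # Right-align so the last column is always the most recent event.
    table = np.full((len(local_attributions), max_len), np.nan)
    for row, attr in enumerate(local_attributions):
        table[row, max_len - len(attr):] = attr
    relative_index = np.arange(-max_len + 1, 1)  # 0 is the most recent event
    counts = np.sum(~np.isnan(table), axis=0)
    return relative_index, np.nanmean(table, axis=0), counts


rel_idx, mean_attr, counts = aggregate_event_attributions(
    [np.array([0.1, 0.5]), np.array([0.2, 0.0, 0.3])]
)
```

Reporting the count next to each mean helps when older positions are covered by only a few sequences.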

Global feature level explanations

feat_explain_all() calculates TimeSHAP feature-level explanations for multiple instances given user-defined parameters.

plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().

Global report

global_report() calculates TimeSHAP explanations for multiple instances, aggregating the explanations on two plots and returning them.

Tutorial

In order to demonstrate TimeSHAP interfaces and methods, you can consult AReM.ipynb. In this tutorial we get an open-source dataset, process it, train a PyTorch recurrent model with it and use TimeSHAP to explain it, showcasing all previously described methods.

Additionally, we also train a TensorFlow model on the same dataset in AReM_TF.ipynb.

Repository Structure

Citing TimeSHAP

@inproceedings{bento2021timeshap,
    author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
    title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
    year = {2021},
    isbn = {9781450383325},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3447548.3467166},
    doi = {10.1145/3447548.3467166},
    booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
    pages = {2565–2573},
    numpages = {9},
    keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
    location = {Virtual Event, Singapore},
    series = {KDD '21}
}
Comments
  • Error in running the example notebook (AReM_TF)

    Nice work! I have been trying to run one of the tutorial notebooks in the repository (i.e., AReM_TF), but I faced an error. The notebook chunk that produces the error is:

    from timeshap.explainer import local_report, local_pruning
    
    pruning_dict = {'tol': 0.025}
    event_dict = {'rs': 42, 'nsamples': 320}
    feature_dict = {'rs': 42, 'nsamples': 320, 'feature_names': model_features, 'plot_features': plot_feats}
    cell_dict = {'rs': 42, 'nsamples': 320, 'top_x_feats': 2, 'top_x_events': 2}
    local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict,cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    

    and the produced error is as follows: [Screenshots: error output, not preserved]

    It would be great if you could help me with this error.

    opened by aminnayebi 7
  • Unable to install timeshap package

    Hello @feedzaiadmin , @saleiro ,

    I am unable to run the command "pip install timeshap" on my Ubuntu system. It throws the error below:

    ERROR: Could not find a version that satisfies the requirement timeshap (from versions: none)
    ERROR: No matching distribution found for timeshap

    This looks like a compatibility issue to me. I have tried Python versions 3.6 and 3.8. Is there a specific version that supports it? I will be waiting for a response.

    Thanks

    opened by vishants98 5
  • How to adapt the transformation function to account for variable sequence length?

    I am trying to use TimeSHAP on my use case. As I understand it, in the AReM example the df_to_numpy function transforms the data so that a prediction is made for the last value of each sequence (see the screen below):

    [Screenshot: AReM data layout, not preserved]

    In the case of the AReM tutorial data, the predictions are based on the whole sequence: all rows (Row IDs 1-10) are used for sequence ID 1 (light blue), and the prediction is made for timestamp 10 (dark blue; Row ID 10). Next, the light orange rows (Row IDs 11-20) are used to predict the label marked in dark orange (Row ID 20).

    In my use case, the model predicts on a rolling-window basis and I need predictions for every row (not only one per sequence). See the screen and explanation below.

    [Screenshot: rolling-window data layout, not preserved]

    Let's say my rolling window is 6: Row IDs 1-6 (light green) are used to predict Row ID 7 (dark green), then Row IDs 2-7 (light grey) are used to predict Row ID 8 (dark grey), etc. When a new sequence starts, we repeat the process, so we take Row IDs 11-16 and predict Row ID 17, etc. For my use case, it's important to evaluate the predictions for every Row ID, not only for the whole sequence.

    The problem I am facing is that when I try to run the function get_avg_score_with_avg_event on the data defined as in the picture above, I get the following error:

    [Screenshot: error message, not preserved]

    The way my data is transformed from 2D into 3D format is defined by the function below:

    [Screenshot: transformation function, not preserved]

    My question is whether it is possible to make TimeSHAP work with data transformed in the way described in my use case. When I use the transformation defined in your function df_to_numpy, I do not get an error; however, it is not adapted to my use case.

    opened by grzechowiak 4
  • Issue in reproducing TimeSHAP Tutorial - TensorFlow - AReM dataset

    Foremost, thanks for the library, really great job!

    I am trying to reproduce your TimeSHAP Tutorial for TF and I am having an issue in the section for Global Explanations when running the global_report() function (screen below):

    [Screenshot: error raised by global_report(), not preserved]

    The error refers to encoding the \u2264 character, which is the <= sign. I tried to solve it myself by modifying pruning.py according to the error, adding encoding="utf-8" on line 326, with open(file_path, 'a', newline='') as file:, however it didn't solve the problem. Any advice is very welcome!

    Also, for completeness, I want to mention that I had a problem loading the data (error screens below). I was able to solve the problem only by deleting 2 datasets, cycling/dataset9.csv and cycling/dataset14.csv, after which the rest of the code worked.

    [Screenshots 1/2 and 2/2: data loading errors, not preserved]

    opened by grzechowiak 4
  • Timeshap for regression

    I am working on a time series forecasting problem using LSTM. Can I use timeshap for such a regression problem? Do you by chance have a demo for regression?

    Thanks

    opened by mgorjis 3
  • Plot Coalition Pruning is not working

    Hi, I am doing a project as part of a Master's thesis. I was testing your introductory notebook with the TensorFlow implementation, but it does not work because of an error raised by one of your plot functions. The funny thing is that your other notebook, i.e., the one with the torch model, works fine. I managed to find the issue: it is raised by an inner method called solve_negatives_method in src/timeshap/plot/pruning.py, which raises an InvalidIndexError caused by this specific row:

        df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    As far as I understand, we are passing a list-like index, usually holding only one value, to the pandas.DataFrame.at accessor, which requires a single scalar row label and a column label. As such, one solution that works for now could be:

        df.at[corresponding_row.index[0], 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    I want to thank you very much for the work done, and I will be happy to discuss further results of my thesis with you. :) Hoping for the best and looking forward to hearing from you, Eric

    opened by Erhtric 2
  • TimeSHAP for text?

    I'm working with a 1-layer GRU for text classification that takes BERT embeddings at the input. Each input sequence is of the shape (sequence length, bert-embedding-dimension). I'm looking for word level attribution scores for each sequence's prediction. Currently with the captum integrated gradients and occlusion explainers, I get attribution scores that are almost always the last few words of each sequence. This seems like it's stemming from the directional processing of GRU - any thoughts?

    Do you think TimeSHAP would be applicable to my use case? I suppose I could consider each word as an event and each embedding dimension as a feature, and then use the event-level local explanations from the library? However, note that in my case the events (i.e., words) from the beginning of the sequence could be more important than those at the end of the sequence (i.e., the most recent ones). This violates the assumptions you use for your approximation (i.e., pruning), so perhaps it's not applicable to text?

    opened by itsmemala 2
  • Intuition around pruning/baseline selection

    While calculating a global report, I'm currently running into errors that "Score difference between baseline and instance is too low < 0.1...Consider choosing another baseline." My baselines have been the average_event and average_sequence. I also notice that occasionally the values in the error change with different pruning tolerance, but not in a consistent way. Do you have advice for dealing with this? Thanks.

    opened by xydisla 2
  • CNN model

    I currently have a CNN model, and previously had to do some strange hacking to get time series importance values. Your package now shines a better light on this issue. For multivariate forecasts, CNNs still fare well, sometimes better than RNNs. From what I can see, there is no reason why CNNs wouldn't work with your software. Let me know if I have that wrong.

    opened by firmai 2
  • How to speed up TIMESHAP computation

    Hi all!

    The package itself is really interesting and intuitive to use, but I want to speed up the TimeSHAP computation. Can I use a GPU to calculate Shapley values? I used the device parameter in TorchModelWrapper, but the GPU efficiency is too low to accelerate the TimeSHAP computation. Any suggestions would be appreciated.

    opened by Changshu135 2
  • Error when executing local_report on TF example: InvalidIndexError: Int64Index([1], dtype='int64')

    Hi all! When executing your TF example notebook without any changes, it fails at this line: local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)

    with the following error: InvalidIndexError: Int64Index([1], dtype='int64')

    Stack trace:

    TypeError                                 Traceback (most recent call last)
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621, in Index.get_loc(self, key, method, tolerance)
       3620 try:
    -> 3621     return self._engine.get_loc(casted_key)
       3622 except KeyError as err:
    
    File pandas/_libs/index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()
    
    File pandas/_libs/index.pyx:142, in pandas._libs.index.IndexEngine.get_loc()
    
    TypeError: 'Int64Index([1], dtype='int64')' is an invalid key
    
    During handling of the above exception, another exception occurred:
    
    InvalidIndexError                         Traceback (most recent call last)
    Input In [27], in <cell line: 7>()
          5 feature_dict = {'rs': 42, 'nsamples': 32000, 'feature_names': model_features, 'plot_features': plot_feats}
          6 cell_dict = {'rs': 42, 'nsamples': 32000, 'top_x_feats': 2, 'top_x_events': 2}
    ----> 7 local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    
    File ~/temp/timeshap/src/timeshap/explainer/local_methods.py:139, in local_report(f, data, pruning_dict, event_dict, feature_dict, cell_dict, entity_uuid, entity_col, time_col, model_features, baseline, verbose)
        137 pruning_idx = data.shape[1] + coal_prun_idx
        138 plot_lim = max(abs(coal_prun_idx)+10, 40)
    --> 139 pruning_plot = plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_lim)
        141 event_data = local_event(f, data, event_dict, entity_uuid, entity_col, baseline, pruning_idx)
        142 event_plot = plot_event_heatmap(event_data)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:53, in plot_temp_coalition_pruning(df, pruned_idx, plot_limit, solve_negatives)
         51 df = df[df['t (event index)'] >= -plot_limit]
         52 if solve_negatives:
    ---> 53     df = solve_negatives_method(df)
         55 base = (alt.Chart(df).encode(
         56     x=alt.X("t (event index)", axis=alt.Axis(title='t (event index)', labelFontSize=15,
         57                           titleFontSize=15)),
       (...)
         70 )
         71 )
         73 area_chart = base.mark_area(opacity=0.5)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:47, in plot_temp_coalition_pruning.<locals>.solve_negatives_method(df)
         45 for idx, row in negative_values.iterrows():
         46     corresponding_row = df[np.logical_and(df['t (event index)'] == row['t (event index)'], ~(df['Coalition'] == row['Coalition']))]
    ---> 47     df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
         48     df.at[idx, 'Shapley Value'] = 0
         49 return df
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2281, in _AtIndexer.__setitem__(self, key, value)
       2278     self.obj.loc[key] = value
       2279     return
    -> 2281 return super().__setitem__(key, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2236, in _ScalarAccessIndexer.__setitem__(self, key, value)
       2233 if len(key) != self.ndim:
       2234     raise ValueError("Not enough indexers for scalar access (setting)!")
    -> 2236 self.obj._set_value(*key, value=value, takeable=self._takeable)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/frame.py:3869, in DataFrame._set_value(self, index, col, value, takeable)
       3867 else:
       3868     series = self._get_item_cache(col)
    -> 3869     loc = self.index.get_loc(index)
       3871 # setitem_inplace will do validation that may raise TypeError
       3872 #  or ValueError
       3873 series._mgr.setitem_inplace(loc, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3628, in Index.get_loc(self, key, method, tolerance)
       3623         raise KeyError(key) from err
       3624     except TypeError:
       3625         # If we have a listlike key, _check_indexing_error will raise
       3626         #  InvalidIndexError. Otherwise we fall through and re-raise
       3627         #  the TypeError.
    -> 3628         self._check_indexing_error(key)
       3629         raise
       3631 # GH#42269
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:5637, in Index._check_indexing_error(self, key)
       5633 def _check_indexing_error(self, key):
       5634     if not is_scalar(key):
       5635         # if key is not a scalar, directly raise an error (the code below
       5636         # would convert to numpy arrays and raise later any way) - GH29926
    -> 5637         raise InvalidIndexError(key)
    
    InvalidIndexError: Int64Index([1], dtype='int64')
    
    opened by JulianKlug 2
Owner
Feedzai