Implementation of the paper "Shapley Explanation Networks"

Overview

Shapley Explanation Networks

Implementation of the paper "Shapley Explanation Networks" at ICLR 2021. Note that this repo relies heavily on named tensors, an experimental PyTorch feature. The authors found the ideas confusing to implement with plain positional dimensions, and named tensors made the implementation far easier.
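For readers new to named tensors, the sketch below shows the basic PyTorch calls involved. The dimension names "batch" and "features" are made up for illustration and are not the names this repo uses (it exports its own constants such as NAME_HEIGHT and NAME_WIDTH). Because the feature is experimental, PyTorch may print a warning when the first named tensor is created.

```python
import torch

# Create a tensor whose dimensions carry names instead of bare positions.
# "batch" and "features" are illustrative names, not the repo's own.
x = torch.randn(3, 4, names=("batch", "features"))
print(x.names)  # ('batch', 'features')

# Reduce over a dimension by name rather than by index.
per_sample = x.sum("features")
print(per_sample.names)  # ('batch',)

# Reorder dimensions by name; no need to remember positional indices.
transposed = x.align_to("features", "batch")
print(tuple(transposed.shape))  # (4, 3)

# Drop the names to get an ordinary unnamed tensor back.
plain = x.rename(None)
print(plain.names)  # (None, None)
```

Referring to dimensions by name is what makes the channel, feature, and spatial bookkeeping easier to follow than positional indexing.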

Dependencies

To run ShapNets alone, you mostly need only PyTorch, NumPy, and SciPy.

Usage

For a Shapley Module:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule

b_size = 3
features = 4
out = 1
dims = ModuleDimensions(
    features=features,
    in_channel=1,
    out_channel=out
)

sm = ShapleyModule(
    inner_function=nn.Linear(features, out),
    dimensions=dims
)
sm(torch.randn(b_size, features), explain=True)
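To make concrete what a Shapley explanation of a small function looks like, here is a minimal standard-library sketch that computes exact Shapley values by enumerating feature subsets. It is not the repository's implementation. The helper name shapley_values, the zero reference vector, and the example weights are all assumptions made for illustration; the README does not say which reference values ShapleyModule uses.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, reference):
    """Exact Shapley values of f at x; absent features take the reference value."""
    n = len(x)

    def value(subset):
        # Features in `subset` come from x, all others from the reference.
        masked = [x[i] if i in subset else reference[i] for i in range(n)]
        return f(masked)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size that exclude feature i.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Hypothetical linear function with 4 features, mirroring nn.Linear(4, 1).
weights = [0.5, -1.0, 2.0, 0.25]
bias = 0.1


def linear(v):
    return sum(w * vi for w, vi in zip(weights, v)) + bias


x = [1.0, 2.0, -1.0, 4.0]
reference = [0.0, 0.0, 0.0, 0.0]  # assumed all-zero reference
phi = shapley_values(linear, x, reference)
print(phi)
```

For a linear function, each Shapley value reduces to weight times (input minus reference), and the values sum to f(x) - f(reference). Enumeration costs exponential time in the number of features, so it is only practical for small modules such as the 4-feature one above.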

For a Shallow ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, OverlappingShallowShapleyNetwork

batch_size = 32
class_num = 10
dim = 32

overlapping_modules = [
    ShapleyModule(
        inner_function=nn.Sequential(nn.Linear(2, class_num)),
        dimensions=ModuleDimensions(
            features=2, in_channel=1, out_channel=class_num
        ),
    ) for _ in range(dim * (dim - 1) // 2)
]
shallow_shapnet = OverlappingShallowShapleyNetwork(
    list_modules=overlapping_modules
)
inputs = torch.randn(batch_size, dim)
# Plain forward pass
shallow_shapnet(inputs)
# Forward pass that also returns the explanation
output, bias = shallow_shapnet(inputs, explain=True)
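The module count dim * (dim - 1) // 2 used above is the number of unordered feature pairs, one two-feature module per pair. The short sketch below enumerates those pairs with the standard library. The lexicographic order shown is only an illustration; how the library assigns pairs to modules is not stated in this README.

```python
from itertools import combinations

dim = 32

# Every unordered pair (i, j) with i < j, in lexicographic order.
feature_pairs = list(combinations(range(dim), 2))
num_pairs = len(feature_pairs)

print(num_pairs)          # 496
print(feature_pairs[:3])  # [(0, 1), (0, 2), (0, 3)]
```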

For a Deep ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, ShallowShapleyNetwork, DeepShapleyNetwork

dim = 32
dim_input_channels = 1
class_num = 10
inputs = torch.randn(32, dim)


dims = ModuleDimensions(
    features=dim,
    in_channel=dim_input_channels,
    out_channel=class_num
)
deep_shapnet = DeepShapleyNetwork(
    list_shapnets=[
        ShallowShapleyNetwork(
            module_dict=nn.ModuleDict({
                "(0, 2)": ShapleyModule(
                    inner_function=nn.Linear(2, class_num),
                    dimensions=ModuleDimensions(
                        features=2, in_channel=1, out_channel=class_num
                    )
                )},
            ),
            dimensions=ModuleDimensions(dim, 1, class_num)
        ),
    ],
)
deep_shapnet(inputs)
outputs = deep_shapnet(inputs, explain=True, )

For a vision model:

import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# Imports ShapNet
# =============================================================================
from ShapNet import DeepConvShapNet, ShallowConvShapleyNetwork, ShapleyModule
from ShapNet.utils import ModuleDimensions, NAME_HEIGHT, NAME_WIDTH, \
    process_list_sizes

num_channels = 3
num_classes = 10
height = 32
width = 32
list_channels = [3, 16, 10]
pruning = [0.2, 0.]
kernel_sizes = process_list_sizes([2, (1, 3), ])
dilations = process_list_sizes([1, 2])
paddings = process_list_sizes([0, 0])
strides = process_list_sizes([1, 1])

args = {
    "list_shapnets": [
        ShallowConvShapleyNetwork(
            shapley_module=ShapleyModule(
                inner_function=nn.Sequential(
                    nn.Linear(
                        # Cast to a Python int, as done for `features` below
                        int(np.prod(kernel_sizes[i])) * list_channels[i],
                        list_channels[i + 1]),
                    nn.LeakyReLU()
                ),
                dimensions=ModuleDimensions(
                    features=int(np.prod(kernel_sizes[i])),
                    in_channel=list_channels[i],
                    out_channel=list_channels[i + 1])
            ),
            reference_values=None,
            kernel_size=kernel_sizes[i],
            dilation=dilations[i],
            padding=paddings[i],
            stride=strides[i]
        ) for i in range(len(list_channels) - 1)
    ],
    "reference_values": None,
    "residual": False,
    "named_output": False,
    "pruning": pruning
}

dcs = DeepConvShapNet(**args)
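As a sanity check on the configuration above, the sketch below works out the input width of each inner nn.Linear and the spatial size after each layer. It rests on two assumptions that this README does not confirm: that process_list_sizes expands an integer into a (height, width) pair, and that the convolutional ShapNet layers follow the standard convolution output-size arithmetic. The helpers as_pair and conv_output_size are hypothetical and not part of ShapNet, and any effect of pruning is ignored.

```python
def as_pair(value):
    # Assumed behaviour of process_list_sizes: an int k becomes (k, k).
    return (value, value) if isinstance(value, int) else tuple(value)


def conv_output_size(size, kernel, dilation=1, padding=0, stride=1):
    # Standard convolution output-size formula (as used by nn.Conv2d / nn.Unfold).
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


list_channels = [3, 16, 10]
kernel_sizes = [as_pair(k) for k in [2, (1, 3)]]
dilations = [as_pair(d) for d in [1, 2]]
paddings = [as_pair(p) for p in [0, 0]]
strides = [as_pair(s) for s in [1, 1]]

height, width = 32, 32
layers = []
for i in range(len(list_channels) - 1):
    kh, kw = kernel_sizes[i]
    # Input width of the inner nn.Linear: kernel area times input channels.
    in_features = kh * kw * list_channels[i]
    height = conv_output_size(height, kh, dilations[i][0], paddings[i][0], strides[i][0])
    width = conv_output_size(width, kw, dilations[i][1], paddings[i][1], strides[i][1])
    layers.append((in_features, list_channels[i + 1], height, width))

for in_features, out_channels, h, w in layers:
    print(in_features, out_channels, h, w)
```

Under these assumptions the first layer maps 12 inputs to 16 channels on a 31 x 31 grid, and the second maps 48 inputs to 10 channels on a 31 x 27 grid.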

Citation

If you find this useful, please cite our work as:

@inproceedings{
wang2021shapley,
title={Shapley Explanation Networks},
author={Rui Wang and Xiaoqian Wang and David I. Inouye},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vsU0efpivw}
}
Owner
Prof. David I. Inouye's research lab at Purdue University.