Implementation of the paper "Shapley Explanation Networks"

Overview

Shapley Explanation Networks

Implementation of the paper "Shapley Explanation Networks" at ICLR 2021. Note that this repo relies heavily on PyTorch's experimental named tensors feature. We, the authors, found the ideas very confusing to implement without it, and named tensors made the work tremendously easier.
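For readers unfamiliar with named tensors, here is a minimal sketch of the PyTorch feature itself, independent of ShapNet. The dimension names `N`, `C`, `H`, `W` are arbitrary choices for illustration, and because named tensors are a prototype feature, PyTorch may warn when they are used.

```python
import warnings

import torch

# Named tensors are a prototype feature; silence the UserWarning PyTorch emits.
warnings.filterwarnings("ignore", message=".*Named tensors.*")

# Attach a name to every dimension at creation time.
images = torch.randn(2, 3, 4, 5, names=("N", "C", "H", "W"))
print(images.names)  # ('N', 'C', 'H', 'W')

# Reductions can refer to a dimension by name instead of by position.
per_channel = images.sum("C")
print(per_channel.names)  # ('N', 'H', 'W')

# align_to permutes dimensions by name, so the target layout is explicit.
channels_last = images.align_to("N", "H", "W", "C")
print(channels_last.shape)  # torch.Size([2, 4, 5, 3])

# An unnamed tensor can be given names later with refine_names.
named = torch.zeros(2, 3).refine_names("N", "C")
print(named.names)  # ('N', 'C')

# rename(None) drops all names before handing the tensor to name-unaware code.
unnamed = named.rename(None)
print(unnamed.names)  # (None, None)
```

Working by name rather than by position is what makes permutations and reductions over batch, channel, and feature dimensions less error-prone.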

Dependencies

To run ShapNets alone, you mostly need only PyTorch, NumPy, and SciPy.
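A quick way to confirm these packages are importable before running the examples is a small standard-library check. The helper `missing_packages` is hypothetical and not part of this repo.

```python
import importlib.util

# Import names of the core dependencies listed above.
REQUIRED = ("torch", "numpy", "scipy")


def missing_packages(names=REQUIRED):
    """Hypothetical helper: return the names that cannot be found on the Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All core dependencies are installed.")
```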

Usage

For a Shapley Module:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule

b_size = 3
features = 4
out = 1
dims = ModuleDimensions(
    features=features,
    in_channel=1,
    out_channel=out
)

sm = ShapleyModule(
    inner_function=nn.Linear(features, out),
    dimensions=dims
)
sm(torch.randn(b_size, features), explain=True)
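Conceptually, a Shapley Module computes exact Shapley values of its `inner_function` over its small set of input features. The brute-force sketch below illustrates that quantity; it is not the repo's implementation. It assumes that features "absent" from a coalition are replaced by a reference (baseline) value, and it enumerates every coalition, which is only feasible for a handful of features such as the 4 used above.

```python
from __future__ import annotations

import itertools
import math
from typing import Callable, Sequence


def exact_shapley(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    reference: Sequence[float],
) -> list[float]:
    """Illustrative brute force (not the ShapNet code): exact Shapley values of f at x.

    Assumption: a feature outside the coalition takes its `reference` value.
    """
    n = len(x)

    def value(coalition: frozenset) -> float:
        # Coalition members keep their input value; everyone else uses the reference.
        z = [x[k] if k in coalition else reference[k] for k in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [k for k in range(n) if k != i]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size.
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                s = frozenset(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


if __name__ == "__main__":
    w = [2.0, -1.0, 0.5, 3.0]

    def linear(z):
        return sum(wk * zk for wk, zk in zip(w, z)) + 1.0

    x = [1.0, 2.0, 3.0, 4.0]
    ref = [0.0, 0.0, 0.0, 0.0]
    phi = exact_shapley(linear, x, ref)
    # For a linear function phi_i = w_i * (x_i - ref_i), so approximately [2.0, -2.0, 1.5, 12.0].
    print(phi)
    # Efficiency: the attributions sum to f(x) - f(reference).
    print(sum(phi), linear(x) - linear(ref))
```

The cost grows as 2^n evaluations of `f`, which is why each Shapley Module is kept small and larger inputs are covered by combining many such modules.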

For a Shallow ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, OverlappingShallowShapleyNetwork

batch_size = 32
class_num = 10
dim = 32

overlapping_modules = [
    ShapleyModule(
        inner_function=nn.Sequential(nn.Linear(2, class_num)),
        dimensions=ModuleDimensions(
            features=2, in_channel=1, out_channel=class_num
        ),
    ) for _ in range(dim * (dim - 1) // 2)
]
shallow_shapnet = OverlappingShallowShapleyNetwork(
    list_modules=overlapping_modules
)
inputs = torch.randn(batch_size, dim)
# Plain forward pass (prediction only)
shallow_shapnet(inputs)
# With explain=True, the call returns two values
output, bias = shallow_shapnet(inputs, explain=True)
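The Shallow ShapNet above builds one two-feature module per unordered pair of inputs, which is where `dim * (dim - 1) // 2` comes from. The sketch below enumerates those pairs and aggregates per-module attributions into per-feature attributions. It assumes, without confirmation from the `OverlappingShallowShapleyNetwork` code, that the network output is the sum of its modules' outputs; under that assumption, summing is justified by the linearity (additivity) of Shapley values. The pair ordering is also an illustrative choice.

```python
from __future__ import annotations

import itertools


def feature_pairs(dim: int) -> list[tuple[int, int]]:
    """All unordered pairs (i, j) with i < j over `dim` features."""
    return list(itertools.combinations(range(dim), 2))


def aggregate_pairwise(
    dim: int, pair_attributions: dict[tuple[int, int], tuple[float, float]]
) -> list[float]:
    """Sum per-pair attributions into one attribution per feature.

    `pair_attributions` maps (i, j) -> (phi_i, phi_j), the Shapley values a
    two-feature module assigns to its own inputs. ASSUMPTION: the network output
    is the sum of the module outputs, so linearity of Shapley values makes the
    network-level attribution of feature k the sum of k's module-level values.
    """
    totals = [0.0] * dim
    for (i, j), (phi_i, phi_j) in pair_attributions.items():
        totals[i] += phi_i
        totals[j] += phi_j
    return totals


if __name__ == "__main__":
    print(len(feature_pairs(32)))  # 496, i.e. 32 * 31 // 2
    toy = {(0, 1): (1.0, 2.0), (0, 2): (0.5, -1.0), (1, 2): (0.0, 3.0)}
    print(aggregate_pairwise(3, toy))  # [1.5, 2.0, 2.0]
```

A module on pair (i, j) ignores every other feature, so those features receive zero attribution from it; that is why only two entries per module need to be accumulated.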

For a Deep ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, ShallowShapleyNetwork, DeepShapleyNetwork

dim = 32
dim_input_channels = 1
class_num = 10
inputs = torch.randn(32, dim)

dims = ModuleDimensions(
    features=dim,
    in_channel=dim_input_channels,
    out_channel=class_num
)
deep_shapnet = DeepShapleyNetwork(
    list_shapnets=[
        ShallowShapleyNetwork(
            module_dict=nn.ModuleDict({
                "(0, 2)": ShapleyModule(
                    inner_function=nn.Linear(2, class_num),
                    dimensions=ModuleDimensions(
                        features=2, in_channel=1, out_channel=class_num
                    )
                )},
            ),
            dimensions=ModuleDimensions(dim, 1, class_num)
        ),
    ],
)
deep_shapnet(inputs)
outputs = deep_shapnet(inputs, explain=True, )
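In the Deep ShapNet example, the `module_dict` key `"(0, 2)"` appears to name the input features that module acts on (features 0 and 2). Assuming that convention, the hypothetical helpers below format such keys, parse them back into index tuples, and generate keys for disjoint pairs of features. They are plain Python and do not call the ShapNet API.

```python
from __future__ import annotations

import ast


def pair_key(indices: tuple[int, ...]) -> str:
    """Hypothetical helper: format feature indices as a key such as "(0, 2)"."""
    return str(tuple(indices))


def parse_key(key: str) -> tuple[int, ...]:
    """Hypothetical helper: turn a key such as "(0, 2)" back into a tuple of ints.

    Assumes keys encode feature indices, as the "(0, 2)" key above suggests.
    """
    value = ast.literal_eval(key)
    if not isinstance(value, tuple) or not all(isinstance(v, int) for v in value):
        raise ValueError(f"not a tuple of feature indices: {key!r}")
    return value


def disjoint_pair_keys(dim: int) -> list[str]:
    """Keys for non-overlapping pairs (0, 1), (2, 3), ... covering `dim` features."""
    if dim % 2:
        raise ValueError("dim must be even to split into disjoint pairs")
    return [pair_key((k, k + 1)) for k in range(0, dim, 2)]


if __name__ == "__main__":
    print(pair_key((0, 2)))       # (0, 2)
    print(parse_key("(0, 2)"))    # (0, 2)
    print(disjoint_pair_keys(6))  # ['(0, 1)', '(2, 3)', '(4, 5)']
```

`ast.literal_eval` is used instead of `eval` so that a malformed key can only produce a literal, never execute code.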

For a vision model:

import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# Import ShapNet
# =============================================================================
from ShapNet import DeepConvShapNet, ShallowConvShapleyNetwork, ShapleyModule
from ShapNet.utils import ModuleDimensions, NAME_HEIGHT, NAME_WIDTH, \
    process_list_sizes

num_channels = 3
num_classes = 10
height = 32
width = 32
list_channels = [3, 16, 10]
pruning = [0.2, 0.]
kernel_sizes = process_list_sizes([2, (1, 3), ])
dilations = process_list_sizes([1, 2])
paddings = process_list_sizes([0, 0])
strides = process_list_sizes([1, 1])

args = {
    "list_shapnets": [
        ShallowConvShapleyNetwork(
            shapley_module=ShapleyModule(
                inner_function=nn.Sequential(
                    nn.Linear(
                        np.prod(kernel_sizes[i]) * list_channels[i],
                        list_channels[i + 1]),
                    nn.LeakyReLU()
                ),
                dimensions=ModuleDimensions(
                    features=int(np.prod(kernel_sizes[i])),
                    in_channel=list_channels[i],
                    out_channel=list_channels[i + 1])
            ),
            reference_values=None,
            kernel_size=kernel_sizes[i],
            dilation=dilations[i],
            padding=paddings[i],
            stride=strides[i]
        ) for i in range(len(list_channels) - 1)
    ],
    "reference_values": None,
    "residual": False,
    "named_output": False,
    "pruning": pruning
}

dcs = DeepConvShapNet(**args)
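To sanity-check the configuration above, the arithmetic below computes each layer's inner `nn.Linear` input width (`np.prod(kernel_sizes[i]) * list_channels[i]`) and the spatial size after each layer for a 32x32 input. It assumes, without confirmation from the ShapNet code, that `process_list_sizes` expands an int `k` into `(k, k)` (mimicked by the hypothetical `as_pair` helper) and that `ShallowConvShapleyNetwork` slides its window like `nn.Conv2d`, so the standard convolution output-size formula applies.

```python
from __future__ import annotations

import math


def as_pair(value: int | tuple[int, int]) -> tuple[int, int]:
    """Hypothetical stand-in for process_list_sizes on one entry: k -> (k, k)."""
    return (value, value) if isinstance(value, int) else tuple(value)


def conv_out(size: int, kernel: int, dilation: int, padding: int, stride: int) -> int:
    """Output length along one axis, ASSUMING nn.Conv2d-style sliding windows."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


list_channels = [3, 16, 10]
kernel_sizes = [as_pair(k) for k in [2, (1, 3)]]
dilations = [as_pair(d) for d in [1, 2]]
paddings = [as_pair(p) for p in [0, 0]]
strides = [as_pair(s) for s in [1, 1]]

height, width = 32, 32
linear_in, shapes = [], []
for i in range(len(list_channels) - 1):
    # Inner nn.Linear width: one input per (kernel position, input channel).
    linear_in.append(math.prod(kernel_sizes[i]) * list_channels[i])
    height = conv_out(height, kernel_sizes[i][0], dilations[i][0], paddings[i][0], strides[i][0])
    width = conv_out(width, kernel_sizes[i][1], dilations[i][1], paddings[i][1], strides[i][1])
    shapes.append((list_channels[i + 1], height, width))

print(linear_in)  # [12, 48]
print(shapes)     # [(16, 31, 31), (10, 31, 27)]
```

Under these assumptions, the `(1, 3)` kernel with dilation 2 shrinks only the width, by `2 * (3 - 1) = 4` pixels, which is easy to overlook when stacking layers.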

Citation

If you find this work useful, please cite it as:

@inproceedings{
wang2021shapley,
title={Shapley Explanation Networks},
author={Rui Wang and Xiaoqian Wang and David I. Inouye},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vsU0efpivw}
}
Owner
Prof. David I. Inouye's research lab at Purdue University.