Adaptive, interpretable wavelets across domains (NeurIPS 2021)

Overview

Adaptive wavelets

Wavelets that adapt to the given data (and, optionally, to a pretrained model). This yields models that are faster, more compressible, and more interpretable.
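One way to read "more compressible": when a wavelet matches the structure of the signal, most detail coefficients come out as zero. The sketch below uses the fixed Haar wavelet in plain Python. It is not awave code and it is not adaptive; it only shows the sparsity that adaptive wavelets try to get for a given dataset.

```python
import math

def haar_dwt_1level(x):
    """One level of the orthonormal Haar DWT (len(x) must be even)."""
    if len(x) % 2:
        raise ValueError("signal length must be even")
    s = 1 / math.sqrt(2)
    approx = [s * (x[i] + x[i + 1]) for i in range(0, len(x), 2)]  # local averages
    detail = [s * (x[i] - x[i + 1]) for i in range(0, len(x), 2)]  # local differences
    return approx, detail

def haar_idwt_1level(approx, detail):
    """Inverse of haar_dwt_1level."""
    s = 1 / math.sqrt(2)
    x = []
    for a, d in zip(approx, detail):
        x.extend([s * (a + d), s * (a - d)])
    return x

# A piecewise-constant signal: every detail coefficient is zero,
# so the signal is fully described by half as many numbers.
signal = [1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0]
approx, detail = haar_dwt_1level(signal)
```

A signal with a different structure (for example, a smooth ramp) would not be sparse under Haar. That is the motivation for learning the filters from the data instead of fixing them in advance.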

📚 docs 📖 demo notebooks

Quickstart

Installation: run pip install awave, or clone the repo and run python setup.py install from the repo directory.

Then you can use the core functions (see the simplest examples in notebooks/demo_simple_2d.ipynb or notebooks/demo_simple_1d.ipynb). See the docs for more information on the arguments to these functions.

Given some data X, you can run the following:

from awave.utils.misc import get_wavefun
from awave.transform2d import DWT2d

wt = DWT2d(wave='db5', J=4)
wt.fit(X=X, lr=1e-1, num_epochs=10)  # this function alternatively accepts a dataloader
X_sparse = wt(X)  # uses the learned adaptive wavelet
phi, psi, x = get_wavefun(wt)  # can also inspect the learned adaptive wavelet
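Gradient descent on filter coefficients does not by itself guarantee that the result is still a wavelet. When inspecting a learned wavelet, it can help to check the standard orthonormal filter-bank conditions on the lowpass filter h: sum(h) = sqrt(2), and sum_k h[k] h[k + 2m] equals 1 for m = 0 and 0 otherwise. A matching highpass filter follows from the alternating flip g[k] = (-1)^k h[L-1-k]. The plain-Python sketch below checks these textbook conditions. The helper names are hypothetical and not part of awave, and it does not describe how awave itself constrains the filters.

```python
import math

def orthonormal_filter_residuals(h):
    """Deviation of lowpass filter h from the standard orthonormal wavelet conditions."""
    n = len(h)
    sum_residual = sum(h) - math.sqrt(2)
    # Double-shift orthogonality: sum_k h[k] h[k + 2m] = 1 if m == 0 else 0.
    shift_residuals = []
    for m in range((n + 1) // 2):
        corr = sum(h[k] * h[k + 2 * m] for k in range(n - 2 * m))
        target = 1.0 if m == 0 else 0.0
        shift_residuals.append(corr - target)
    return sum_residual, shift_residuals

def highpass_from_lowpass(h):
    """Alternating flip: g[k] = (-1)^k * h[L - 1 - k]."""
    L = len(h)
    return [(-1) ** k * h[L - 1 - k] for k in range(L)]

# The 4-tap Daubechies lowpass filter ("db2") satisfies the conditions exactly.
s3, s2 = math.sqrt(3), math.sqrt(2)
db2 = [(1 + s3) / (4 * s2), (3 + s3) / (4 * s2), (3 - s3) / (4 * s2), (1 - s3) / (4 * s2)]
sum_res, shift_res = orthonormal_filter_residuals(db2)
```

For a learned filter, small but nonzero residuals show how far the optimization has drifted from an exact orthonormal wavelet.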

To distill a pretrained model named model, simply pass it as an additional argument to the fit function:

wt.fit(X=X, pretrained_model=model,
       lr=1e-1, num_epochs=10,
       lamL1attr=5) # control how much to regularize the model's attributions
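To illustrate what a weight like lamL1attr trades off, here is a toy objective in plain Python. It combines reconstruction error, an L1 penalty on the wavelet coefficients, and an L1 penalty on the model's attributions to those coefficients. Several parts are assumptions for illustration and are not awave's actual loss: the exact terms, their normalization, the hypothetical names lam_wave and lam_attr, and the use of gradient-times-input attributions for a model that is linear in the coefficients.

```python
def grad_times_input(weights, coeffs):
    """Gradient-times-input attributions for a model f(c) = sum_i w_i * c_i.

    The gradient of f with respect to c_i is w_i, so attribution_i = w_i * c_i.
    """
    return [w * c for w, c in zip(weights, coeffs)]

def awd_style_loss(x, x_hat, coeffs, attributions, lam_wave, lam_attr):
    """Toy objective: reconstruction + coefficient sparsity + attribution sparsity."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)  # mean squared error
    l1_wave = sum(abs(c) for c in coeffs)          # encourages sparse wavelet coefficients
    l1_attr = sum(abs(a) for a in attributions)    # encourages few important coefficients
    return recon + lam_wave * l1_wave + lam_attr * l1_attr

x = [1.0, 2.0, 3.0, 4.0]
coeffs = [2.0, 0.0, -1.0]
attrs = grad_times_input([0.5, 3.0, 0.0], coeffs)
loss = awd_style_loss(x, x, coeffs, attrs, lam_wave=0.1, lam_attr=5.0)
```

Raising lam_attr pushes the optimizer toward wavelets where the pretrained model's prediction depends on only a few coefficients, which is the interpretability goal described in the abstract below.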

Background

Official code for using and reproducing adaptive wavelet distillation (AWD) from the paper "Adaptive wavelet distillation from neural networks through interpretations" (Ha et al., NeurIPS 2021).
Abstract: Recent deep-learning models have achieved impressive prediction performance, but often sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is the goal itself. Moreover, interpretable models are concise and often yield computational efficiency. Here, we propose adaptive wavelet distillation (AWD), a method which aims to distill information from a trained neural network into a wavelet transform. Specifically, AWD penalizes feature attributions of a neural network in the wavelet domain to learn an effective multi-resolution wavelet transform. The resulting model is highly predictive, concise, computationally efficient, and has properties (such as a multi-scale structure) which make it easy to interpret. In close collaboration with domain experts, we showcase how AWD addresses challenges in two real-world settings: cosmological parameter inference and molecular-partner prediction. In both cases, AWD yields a scientifically interpretable and concise model which gives predictive performance better than state-of-the-art neural networks. Moreover, AWD identifies predictive features that are scientifically meaningful in the context of respective domains.
This repository also provides an implementation of "Learning Sparse Wavelet Representations" (Recoskie & Mann, 2018).
Abstract: In this work we propose a method for learning wavelet filters directly from data. We accomplish this by framing the discrete wavelet transform as a modified convolutional neural network. We introduce an autoencoder wavelet transform network that is trained using gradient descent. We show that the model is capable of learning structured wavelet filters from synthetic and real data. The learned wavelets are shown to be similar to traditional wavelets that are derived using Fourier methods. Our method is simple to implement and easily incorporated into neural network architectures. A major advantage to our model is that we can learn from raw audio data.
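The "DWT as a convolutional network" view can be made concrete with a single-level filter bank. Analysis is a stride-2 convolution with a lowpass filter h and a highpass filter g, which acts as the encoder. Synthesis is the transposed operation, which acts as the decoder. For orthonormal filters the decoder inverts the encoder exactly. In a learned version, h would be a trainable parameter. The pure-Python sketch below assumes periodic boundary handling and a fixed db2 filter. It is not the code of awave or of Recoskie & Mann.

```python
import math

def analysis(x, h, g):
    """Encoder: periodic stride-2 correlation with lowpass h and highpass g."""
    N = len(x)
    a = [sum(h[k] * x[(2 * n + k) % N] for k in range(len(h))) for n in range(N // 2)]
    d = [sum(g[k] * x[(2 * n + k) % N] for k in range(len(g))) for n in range(N // 2)]
    return a, d

def synthesis(a, d, h, g):
    """Decoder: the transpose of `analysis` (a stride-2 transposed convolution)."""
    N = 2 * len(a)
    x = [0.0] * N
    for n in range(len(a)):
        for k in range(len(h)):
            x[(2 * n + k) % N] += h[k] * a[n] + g[k] * d[n]
    return x

# Fixed db2 lowpass filter and its alternating-flip highpass partner.
s3, s2 = math.sqrt(3), math.sqrt(2)
H = [(1 + s3) / (4 * s2), (3 + s3) / (4 * s2), (3 - s3) / (4 * s2), (1 - s3) / (4 * s2)]
G = [(-1) ** k * H[len(H) - 1 - k] for k in range(len(H))]

x = [4.0, 1.0, -2.0, 3.0, 0.5, 0.0, 7.0, -1.0]
a, d = analysis(x, H, G)
x_rec = synthesis(a, d, H, G)  # equals x up to floating-point error
```

Because the synthesis step is the transpose of the analysis step, reconstruction is exact only while the filters stay orthonormal. A learned filter that drifts away from those conditions loses perfect reconstruction, so the reconstruction loss also acts as a check on the learned filter.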

Related work

  • TRIM (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

If this package is useful for you, please cite the following!

@article{ha2021adaptive,
  title={Adaptive wavelet distillation from neural networks through interpretations},
  author={Ha, Wooseok and Singh, Chandan and Lanusse, Francois and Song, Eli and Dang, Song and He, Kangmin and Upadhyayula, Srigokul and Yu, Bin},
  journal={arXiv preprint arXiv:2107.09145},
  year={2021}
}
Owner: Yu Group (Bin Yu Group at UC Berkeley)