Adaptive, interpretable wavelets across domains (NeurIPS 2021)

Overview

Adaptive wavelets

Wavelets that adapt to given data (and, optionally, to a pre-trained model). This yields models that are faster, more compressible, and more interpretable.

📚 docs • 📖 demo notebooks

Quickstart

Installation: run pip install awave, or clone the repo and run python setup.py install from the repo directory.

Then use the core functions (the simplest examples are in notebooks/demo_simple_2d.ipynb and notebooks/demo_simple_1d.ipynb). See the docs for more information on the arguments of these functions.

Given some data X, you can run the following:

from awave.utils.misc import get_wavefun
from awave.transform2d import DWT2d

wt = DWT2d(wave='db5', J=4)
wt.fit(X=X, lr=1e-1, num_epochs=10)  # this function alternatively accepts a dataloader
X_sparse = wt(X)  # uses the learned adaptive wavelet
phi, psi, x = get_wavefun(wt)  # can also inspect the learned adaptive wavelet

To distill a pretrained model named model, simply pass it as an additional argument to the fit function:

wt.fit(X=X, pretrained_model=model,
       lr=1e-1, num_epochs=10,
       lamL1attr=5) # control how much to regularize the model's attributions
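
As a rough illustration of what a weight such as lamL1attr trades off, the sketch below combines a reconstruction error with a weighted L1 penalty on attribution scores. It is not the package's loss: the function names, the parameter lam_l1_attr, and the two-term form are assumptions made for illustration, and the actual objective may include further terms.

```python
def l1_norm(values):
    """Sum of absolute values; small when most entries are near zero."""
    return sum(abs(v) for v in values)


def penalized_objective(x, x_reconstructed, attributions, lam_l1_attr):
    """Hypothetical two-term objective: mean squared reconstruction error
    plus a weighted L1 penalty on the attribution scores."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_reconstructed)) / len(x)
    return recon + lam_l1_attr * l1_norm(attributions)


# Perfect reconstruction, so only the attribution penalty contributes.
loss = penalized_objective([1.0, 2.0], [1.0, 2.0], [0.5, -0.5], lam_l1_attr=5)
```

Under this assumed form, a larger weight pushes more attribution scores toward zero, possibly at some cost in reconstruction accuracy.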

Background

Official code for using / reproducing AWD from the paper "Adaptive wavelet distillation from neural networks through interpretations" (Ha et al. NeurIPS, 2021).
Abstract: Recent deep-learning models have achieved impressive prediction performance, but often sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is the goal itself. Moreover, interpretable models are concise and often yield computational efficiency. Here, we propose adaptive wavelet distillation (AWD), a method which aims to distill information from a trained neural network into a wavelet transform. Specifically, AWD penalizes feature attributions of a neural network in the wavelet domain to learn an effective multi-resolution wavelet transform. The resulting model is highly predictive, concise, computationally efficient, and has properties (such as a multi-scale structure) which make it easy to interpret. In close collaboration with domain experts, we showcase how AWD addresses challenges in two real-world settings: cosmological parameter inference and molecular-partner prediction. In both cases, AWD yields a scientifically interpretable and concise model which gives predictive performance better than state-of-the-art neural networks. Moreover, AWD identifies predictive features that are scientifically meaningful in the context of respective domains.
Also provides an implementation of "Learning Sparse Wavelet Representations" (Recoskie & Mann, 2018).
Abstract: In this work we propose a method for learning wavelet filters directly from data. We accomplish this by framing the discrete wavelet transform as a modified convolutional neural network. We introduce an autoencoder wavelet transform network that is trained using gradient descent. We show that the model is capable of learning structured wavelet filters from synthetic and real data. The learned wavelets are shown to be similar to traditional wavelets that are derived using Fourier methods. Our method is simple to implement and easily incorporated into neural network architectures. A major advantage to our model is that we can learn from raw audio data.
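
For readers new to wavelets, the following minimal sketch shows the fixed (non-adaptive) starting point that both papers build on: a multi-level discrete wavelet transform, here with the classical Haar filters, written in plain Python. It is independent of this package; all function names are illustrative, and the signal length is assumed to be divisible by 2 at every level.

```python
import math

S = 1 / math.sqrt(2)  # normalization that makes the Haar transform orthonormal


def haar_step(signal):
    """One DWT level: pairwise scaled sums (approximation) and differences (detail)."""
    approx = [(signal[i] + signal[i + 1]) * S for i in range(0, len(signal), 2)]
    detail = [(signal[i] - signal[i + 1]) * S for i in range(0, len(signal), 2)]
    return approx, detail


def haar_inverse_step(approx, detail):
    """Invert one level by interleaving the reconstructed sample pairs."""
    out = []
    for a, d in zip(approx, detail):
        out.append((a + d) * S)
        out.append((a - d) * S)
    return out


def haar_dwt(signal, levels):
    """Multi-level transform: returns the coarsest approximation and the
    detail coefficients ordered from finest to coarsest scale."""
    approx, details = list(signal), []
    for _ in range(levels):
        approx, detail = haar_step(approx)
        details.append(detail)
    return approx, details


def haar_idwt(approx, details):
    """Inverse multi-level transform, undoing the levels from coarsest to finest."""
    for detail in reversed(details):
        approx = haar_inverse_step(approx, detail)
    return approx


# A piecewise-constant signal is sparse in the Haar basis:
# every detail coefficient at both levels is zero.
signal = [4.0, 4.0, 4.0, 4.0, 1.0, 1.0, 1.0, 1.0]
approx, details = haar_dwt(signal, levels=2)
reconstructed = haar_idwt(approx, details)
```

An adaptive method replaces the fixed sum and difference filters with learned ones, aiming for similarly sparse coefficients on the data of interest while keeping the transform invertible.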

Related work

  • TRIM (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

If this package is useful for you, please cite the following!

@article{ha2021adaptive,
  title={Adaptive wavelet distillation from neural networks through interpretations},
  author={Ha, Wooseok and Singh, Chandan and Lanusse, Francois and Song, Eli and Dang, Song and He, Kangmin and Upadhyayula, Srigokul and Yu, Bin},
  journal={arXiv preprint arXiv:2107.09145},
  year={2021}
}
Owner
Yu Group
Bin Yu Group at UC Berkeley