Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Overview

Bayes-Newton

Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Will Wilkinson.

Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). A full list of the implemented methods is given under Implemented Models below.
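To make the "inference as Newton's algorithm" view concrete, here is a minimal NumPy sketch of the textbook Newton iteration for the Laplace approximation in GP classification (Rasmussen, Williams: Gaussian Processes for Machine Learning, 2006, Algorithm 3.1). It is not Bayes-Newton's implementation: the toy data, the `matern52` and `laplace_newton` helpers, and the logistic likelihood are assumptions chosen purely for illustration.

```python
import numpy as np

def matern52(x1, x2, variance=1.0, lengthscale=1.0):
    """Matern-5/2 kernel matrix between two 1-D input vectors (illustrative helper)."""
    r = np.abs(x1[:, None] - x2[None, :]) / lengthscale
    s = np.sqrt(5.0) * r
    return variance * (1.0 + s + s**2 / 3.0) * np.exp(-s)

def laplace_newton(K, y, num_iters=30):
    """Newton iterations for the posterior mode, labels y in {-1, +1}, logistic likelihood."""
    n = K.shape[0]
    t = (y + 1.0) / 2.0                     # labels mapped to {0, 1}
    f = np.zeros(n)
    for _ in range(num_iters):
        pi = 1.0 / (1.0 + np.exp(-f))       # sigmoid(f)
        grad = t - pi                       # gradient of the log-likelihood
        W = pi * (1.0 - pi)                 # negative Hessian of the log-likelihood (diagonal)
        sW = np.sqrt(W)
        B = np.eye(n) + sW[:, None] * K * sW[None, :]
        b = W * f + grad
        a = b - sW * np.linalg.solve(B, sW * (K @ b))
        f = K @ a                           # full Newton step: f = (K^-1 + W)^-1 b
    return f, a

x = np.linspace(-3.0, 3.0, 30)
y = np.where(np.sin(2.0 * x) >= 0.0, 1.0, -1.0)   # toy binary labels
K = matern52(x, x)
f_mode, a = laplace_newton(K, y)

# At the mode the log-posterior gradient vanishes: grad log p(y|f) = K^-1 f = a
pi_mode = 1.0 / (1.0 + np.exp(-f_mode))
residual = np.max(np.abs((y + 1.0) / 2.0 - pi_mode - a))
print(f"max gradient residual: {residual:.2e}")
```

The learning rate passed to the library's inference step plays the role of a step size in updates of this kind; the sketch above always takes the full step.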

Installation

pip install bayesnewton

Example

Given some inputs x and some data y, you can construct a Bayes-Newton model as follows:

import bayesnewton

kern = bayesnewton.kernels.Matern52()  # Matern-5/2 covariance function
lik = bayesnewton.likelihoods.Gaussian()  # Gaussian observation model
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

The training loop (inference and hyperparameter learning) is then set up using objax's built-in functionality:

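import objax

lr_adam = 0.1  # learning rate for the hyperparameter optimiser
lr_newton = 1  # learning rate (step size) for the inference updates
inf_args = {}  # extra keyword arguments for inference/energy; assumed empty here
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())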

@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton, **inf_args)  # perform inference and update variational params
    dE, E = energy(**inf_args)  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)  # update the hyperparameters
    return E

As we are using JAX, we can JIT compile the training loop:

train_op = objax.Jit(train_op)

and then run the training loop:

iters = 20
for i in range(1, iters + 1):
    loss = train_op()

Full demos are available in the demos directory of the repository.

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.

Citing Bayes-Newton

@article{wilkinson2021bayesnewton,
  title = {{B}ayes-{N}ewton Methods for Approximate {B}ayesian Inference with {PSD} Guarantees},
  author = {Wilkinson, William J. and S\"arkk\"a, Simo and Solin, Arno},
  journal={arXiv preprint arXiv:2111.01721},
  year={2021}
}

Implemented Models

For a full list of all the models available, see the model class list.

Variational GPs

  • Variational GP (Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation 2009; Khan, Lin: Conjugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models into Inference in Conjugate Models, AISTATS 2017)
  • Sparse Variational GP (Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015; Adam, Chang, Khan, Solin: Dual Parameterization of Sparse Variational Gaussian Processes, NeurIPS 2021)
  • Markov Variational GP (Chang, Wilkinson, Khan, Solin: Fast Variational Learning in State Space Gaussian Process Models, MLSP 2020)
  • Sparse Markov Variational GP (Adam, Eleftheriadis, Durrande, Artemev, Hensman: Doubly Sparse Variational Gaussian Processes, AISTATS 2020; Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Spatio-Temporal Variational GP (Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)

Expectation Propagation GPs

  • Expectation Propagation GP (Minka: A Family of Algorithms for Approximate Bayesian Inference, PhD thesis 2000)
  • Sparse Expectation Propagation GP (energy not working) (Csato, Opper: Sparse on-line Gaussian processes, Neural Computation 2002; Bui, Yan, Turner: A Unifying Framework for Gaussian Process Pseudo Point Approximations Using Power Expectation Propagation, JMLR 2017)
  • Markov Expectation Propagation GP (Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Expectation Propagation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Laplace/Newton GPs

  • Laplace GP (Rasmussen, Williams: Gaussian Processes for Machine Learning, 2006)
  • Sparse Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Markov Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Sparse Markov Laplace GP

Linearisation GPs

  • Posterior Linearisation GP (García-Fernández, Tronarp, Särkkä: Gaussian Process Classification Using Posterior Linearization, IEEE Signal Processing 2019; Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Sparse Posterior Linearisation GP
  • Markov Posterior Linearisation GP (García-Fernández, Svensson, Särkkä: Iterated Posterior Linearization Smoother, IEEE Automatic Control 2016; Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Posterior Linearisation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Taylor Expansion / Analytical Linearisation GP (Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Markov Taylor GP / Extended Kalman Smoother (Bell: The Iterated Kalman Smoother as a Gauss-Newton method, SIAM Journal on Optimization 1994)
  • Sparse Taylor GP
  • Sparse Markov Taylor GP / Sparse Extended Kalman Smoother (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Gauss-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Gauss-Newton
  • Variational Gauss-Newton
  • PEP Gauss-Newton
  • 2nd-order PL Gauss-Newton

Quasi-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Quasi-Newton
  • Variational Quasi-Newton
  • PEP Quasi-Newton
  • PL Quasi-Newton

GPs with PSD Constraints via Riemannian Gradients

  • VI Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • Newton/Laplace Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • PEP Riemann Grad (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

Others

  • Infinite Horizon GP (Solin, Hensman, Turner: Infinite-Horizon Gaussian Processes, NeurIPS 2018)
  • Parallel Markov GP (with VI, EP, PL, ...) (Särkkä, García-Fernández: Temporal parallelization of Bayesian smoothers; Corenflos, Zhao, Särkkä: Gaussian Process Regression in Logarithmic Time; Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)
  • 2nd-order Posterior Linearisation GP (sparse, Markov, ...) (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
Comments
  • Addition of a Squared Exponential Kernel?

    This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?

    If applicable, I would be happy to submit a PR adding one.

    Just let me know.

    opened by mathDR 7
  • Add sq exponential kernel

    This PR does two things:

    1. Adds a Squared Exponential Kernel. Adopting the gpflow nomenclature for naming, i.e. SquaredExponential inheriting from StationaryKernel. As of now, only the K_r method is populated (others may be populated as needed).
    2. The marathon.py demo was changed to take the SquaredExponential kernel (with lengthscale = 40). This serves as a "test" to ensure that the kernel runs.
    opened by mathDR 2
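For reference, a squared exponential kernel written as a function of the scaled distance r can be sketched in a few lines of NumPy. This is a standalone illustration under the usual definition k(r) = variance · exp(-r²/2), not the code from the PR; the function name and arguments are hypothetical.

```python
import numpy as np

def squared_exponential_K_r(r, variance=1.0):
    """Squared exponential kernel as a function of r = |x - x'| / lengthscale."""
    return variance * np.exp(-0.5 * r**2)

x = np.linspace(0.0, 100.0, 5)       # toy 1-D inputs
lengthscale = 40.0                   # same lengthscale as mentioned for the demo
r = np.abs(x[:, None] - x[None, :]) / lengthscale
K = squared_exponential_K_r(r, variance=2.0)
print(K.shape)
```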
  • question about the equation (64) in `Bayes-Newton` paper

    I'm a little confused by equation (64). It is calculated from equation (63), but where are the other terms in equation (64), such as the denominator that comes from the covariance of p(fn|u)?

    opened by Fangwq 2
  • jitted predict

    Hi, I'm starting to explore your framework. I'm familiar with jax, but not with objax. I noticed that train ops are jitted with objax.Jit, but as my goal is to have fast prediction embedded in some larger jax code, I wonder if predict() can also be jitted? Thanks in advance,

    Regards,

    opened by soldierofhell 2
  • error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

    As the title said, there is an error after running heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method:

    File "heteroscedastic.py", line 101, in train_op
      model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
    File "/BayesNewton-main/bayesnewton/inference.py", line 871, in inference
      mean, jacobian, hessian, quasi_newton_state = self.update_variational_params(batch_ind, lr, **kwargs)
    File "/BayesNewton-main/bayesnewton/inference.py", line 1076, in update_variational_params
      jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
    ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)
    

    Can you tell me what is going wrong? Thanks in advance.

    opened by Fangwq 2
  • How to set initial `Pinf` variable in kernel?

    I note that the initial Pinf variable for Matern-5/2 kernel is as follows:

            Pinf = np.array([[self.variance,    0.0,   -kappa],
                             [0.0,    kappa, 0.0],
                             [-kappa, 0.0,   25.0*self.variance / self.lengthscale**4.0]])
    

    Why is it like that? Are there any references I should follow up on?

    PS: by the way, the data ../data/aq_data.csv is missing.

    opened by Fangwq 2
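One way to see where such a matrix comes from: the Matern-5/2 kernel has a state-space (SDE) form, and Pinf is its stationary state covariance, i.e. the solution of the Lyapunov equation F Pinf + Pinf Fᵀ + L Qc Lᵀ = 0 (see e.g. Särkkä, Solin: Applied Stochastic Differential Equations, 2019). The NumPy sketch below checks this numerically. The feedback matrix F and spectral density Qc are the standard Matern-5/2 ones, written out here as assumptions rather than copied from the library, and so is kappa = 5/3 · variance / lengthscale², since the quoted snippet does not show its definition.

```python
import numpy as np

variance, lengthscale = 1.3, 0.7                   # arbitrary test values
lam = np.sqrt(5.0) / lengthscale
kappa = 5.0 / 3.0 * variance / lengthscale**2.0    # assumed definition of kappa

# Stationary covariance, in the same form as the quoted snippet
Pinf = np.array([[variance, 0.0,   -kappa],
                 [0.0,      kappa, 0.0],
                 [-kappa,   0.0,   25.0 * variance / lengthscale**4.0]])

# Standard Matern-5/2 SDE: dx/dt = F x + L w(t), white noise w with spectral density Qc
F = np.array([[0.0,      1.0,           0.0],
              [0.0,      0.0,           1.0],
              [-lam**3, -3.0 * lam**2, -3.0 * lam]])
L = np.array([[0.0], [0.0], [1.0]])
Qc = 16.0 / 3.0 * variance * lam**5

# Pinf should satisfy the Lyapunov equation up to round-off
lyapunov_residual = F @ Pinf + Pinf @ F.T + Qc * (L @ L.T)
print(np.max(np.abs(lyapunov_residual)))
```

The diagonal entries also have a direct reading: they are the stationary variances of the process, its first derivative and its second derivative.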
  • How to install Newt in conda virtual environment?

    Hi, thank you for sharing your great work.

    I am a little confused about how to install Newt in a conda virtual environment. I would really appreciate it if you could guide me in this regard. Thank you

    opened by mohammad-saber 2
  • How to understand the function `cavity_distribution_tied` in file `basemodels.py` ?

    Just as the title says, how can I understand cavity_distribution_tied in file basemodels.py? Is there any reference I should follow up on? I also note that this code is similar to equation (64) in the BayesNewton paper. Where does it come from?

    opened by Fangwq 1
  • issue with SparseVariationalGP method

    When I run the file demos/regression.py with SparseVariationalGP, the following error occurs:

    AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.
    

    It seems that there is a mistake in the SparseVariationalGP method. Can you help to fix the problem? Thank you very much!

    opened by Fangwq 1
  • Sparse EP energy is incorrect

    The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.

    Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)

    bug 
    opened by wil-j-wil 1
  • Double Precision Issues

    Hi!

    Many thanks for open-sourcing this package.

    I've been using code from this amazing package for my research (preprint here) and have found that the default single precision/float32 is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular,

    • for the Periodic kernel, it is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts().
    • Likewise, the same when the lengthscales are too large in the Matern32 kernel.

    However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, due to the fact that doing so puts all arrays into double precision.

    Currently, my solution is to set the neural network weights to float32 manually, and convert input arrays into float32 before entering the network and the outputs back into float64. However, I was wondering if there could be a more elegant solution as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.

    Software and hardware details:

    objax==1.6.0
    jax==0.3.13
    jaxlib==0.3.10+cuda11.cudnn805
    
    NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
    GeForce RTX 3090 GPUs
    

    Thanks in advance.

    Best, Harrison

    opened by harrisonzhu508 1
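The workaround described in this issue (float64 everywhere, except a float32 network with casts at its boundaries) can be sketched as follows. This is only an illustration of the casting pattern in plain JAX: `tiny_network` is a hypothetical stand-in for an objax module, and nothing here is taken from the poster's code.

```python
import jax
jax.config.update("jax_enable_x64", True)   # make float64 the default for new arrays
import jax.numpy as jnp

def tiny_network(weights, inputs):
    """Hypothetical stand-in for a float32 neural network module."""
    return jnp.tanh(inputs @ weights)

weights32 = jnp.ones((3, 2), dtype=jnp.float32)        # network weights kept in float32
inputs64 = jnp.linspace(0.0, 1.0, 12).reshape(4, 3)    # float64 once x64 is enabled

# Cast down before the network, then back up for the filtering/smoothing code
outputs32 = tiny_network(weights32, inputs64.astype(jnp.float32))
outputs64 = outputs32.astype(jnp.float64)
print(inputs64.dtype, outputs32.dtype, outputs64.dtype)
```

Casting back up only changes the storage type; the values still carry float32 accuracy, so the extra precision applies only to what is computed afterwards.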
  • Cannot run demo, possible incompatibility with latest Jax

    Dear all,

    I am trying to run the demo examples, but I run into the following error:


    ImportError                               Traceback (most recent call last)
    Input In [22], in <cell line: 1>()
    ----> 1 import bayesnewton
          2 import objax
          3 import numpy as np

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
    ----> 1 from . import (
          2     kernels,
          3     utils,
          4     ops,
          5     likelihoods,
          6     models,
          7     basemodels,
          8     inference,
          9     cubature
         10 )
         13 def build_model(model, inf, name='GPModel'):
         14     return type(name, (inf, model), {})

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
          3 import jax.numpy as np
          4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
    ----> 5 from jax.ops import index_add, index
          6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
          7 from warnings import warn

    ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)

    I think it's related to this note from the JAX website:

    The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.

    opened by daniel-trejobanos 3
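For readers hitting the same ImportError, the replacement suggested by the JAX note looks like this. It is a generic sketch of the `.at` property on a toy array, not the actual change made in the bayesnewton source.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Old style (removed from recent JAX releases):
#   from jax.ops import index_add, index
#   x = index_add(x, index[1:3], 2.0)

# Current style: functional update through the .at property (returns a new array)
x = x.at[1:3].add(2.0)
print(x)
```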
  • Latest versions of JAX and objax cause compile slow down

    It is recommended to use the following versions of jax and objax:

    jax==0.2.9
    jaxlib==0.1.60
    objax==1.3.1
    

    This is because of an objax issue which causes the model to JIT compile "twice", i.e. on the first two iterations rather than just the first. This causes a bit of a slowdown for large models, but is not a problem otherwise.

    bug 
    opened by wil-j-wil 0
Releases(v1.2.0)
Owner
AaltoML
Machine learning group at Aalto University led by Prof. Solin