Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Overview

Bayes-Newton

Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Will Wilkinson.

Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the implemented methods, scroll down to the bottom of this page.
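
The package's `__init__.py` defines a helper, `build_model(model, inf, name='GPModel')`, which returns `type(name, (inf, model), {})`. It appears in the traceback quoted in the comments further down this page. The result is a new class that inherits from both an inference mixin and a model class. The sketch below reproduces that helper and applies it to hypothetical toy classes to show how one model can be paired with different inference methods. `ToyMarkovGP`, `ToyLaplace` and `ToyVariational` are illustrative stand-ins, not the library's classes.

```python
# Hypothetical toy classes illustrating the build_model pattern; they are not
# the library's real model or inference classes.


def build_model(model, inf, name="GPModel"):
    # Same body as the helper in bayesnewton/__init__.py: a new class whose
    # bases are (inference mixin, model class).
    return type(name, (inf, model), {})


class ToyMarkovGP:
    """Stand-in for a model class (e.g. a Markov GP)."""

    def describe_model(self):
        return "Markov GP"


class ToyInference:
    """Stand-in base for inference mixins; it calls a method the model supplies."""

    method = "generic"

    def summary(self):
        return f"{self.describe_model()} with {self.method} inference"


class ToyLaplace(ToyInference):
    method = "Laplace"


class ToyVariational(ToyInference):
    method = "variational"


LaplaceMarkovGP = build_model(ToyMarkovGP, ToyLaplace, name="LaplaceMarkovGP")
VariationalMarkovGP = build_model(ToyMarkovGP, ToyVariational, name="VariationalMarkovGP")

print(LaplaceMarkovGP().summary())      # Markov GP with Laplace inference
print(VariationalMarkovGP().summary())  # Markov GP with variational inference
```

The inference class is listed first among the bases, so its methods take precedence in the method resolution order. It can still call methods that the model class supplies.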

Installation

pip install bayesnewton

Example

Given some inputs x and some data y, you can construct a Bayes-Newton model as follows:

import bayesnewton

# x: training inputs, y: training observations (assumed to be defined already)
kern = bayesnewton.kernels.Matern52()
lik = bayesnewton.likelihoods.Gaussian()
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

The training loop (inference and hyperparameter learning) is then set up using objax's built-in functionality:

import objax

lr_adam = 0.1  # learning rate for the Adam hyperparameter optimiser
lr_newton = 1  # step size for the inference (variational parameter) update
inf_args = {}  # extra keyword arguments for inference and energy (none needed here)
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())

@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton, **inf_args)  # perform inference and update variational params
    dE, E = energy(**inf_args)  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)  # update the hyperparameters
    return E

As we are using JAX, we can JIT-compile the training loop:

train_op = objax.Jit(train_op)

and then run it:

iters = 20
for i in range(1, iters + 1):
    loss = train_op()  # one inference step followed by one hyperparameter update

Full demos are available in the demos folder of the repository.

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.

Citing Bayes-Newton

@article{wilkinson2021bayesnewton,
  title = {{B}ayes-{N}ewton Methods for Approximate {B}ayesian Inference with {PSD} Guarantees},
  author = {Wilkinson, William J. and S\"arkk\"a, Simo and Solin, Arno},
  journal={arXiv preprint arXiv:2111.01721},
  year={2021}
}

Implemented Models

For a full list of all the available models, see the model class list.

Variational GPs

  • Variational GP (Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation 2009; Khan, Lin: Conjugate-Computation Variational Inference: Converting Variational Inference in Non-Conjugate Models to Inferences in Conjugate Models, AISTATS 2017)
  • Sparse Variational GP (Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015; Adam, Chang, Khan, Solin: Dual Parameterization of Sparse Variational Gaussian Processes, NeurIPS 2021)
  • Markov Variational GP (Chang, Wilkinson, Khan, Solin: Fast Variational Learning in State Space Gaussian Process Models, MLSP 2020)
  • Sparse Markov Variational GP (Adam, Eleftheriadis, Durrande, Artemev, Hensman: Doubly Sparse Variational Gaussian Processes, AISTATS 2020; Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Spatio-Temporal Variational GP (Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)
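
The Markov models in these lists rely on the fact that GPs with Matérn kernels can be rewritten as linear stochastic differential equations. Posterior computation then reduces to Kalman filtering and smoothing, which scales linearly in the number of time points. The following is a minimal NumPy sketch, not the library's implementation. It assumes a Gaussian likelihood, for which the exact posterior is available. It also assumes a Matérn-1/2 kernel, chosen because its state is scalar; the `Matern52` kernel used in the Example above has a three-dimensional state. The sketch checks that the Kalman/RTS posterior mean matches dense GP regression.

```python
import numpy as np

# Minimal sketch: GP regression with a Matern-1/2 kernel and Gaussian noise,
# solved (a) densely in O(N^3) and (b) as a scalar state-space model with a
# Kalman filter plus Rauch-Tung-Striebel (RTS) smoother in O(N).


def matern12(x1, x2, variance, lengthscale):
    return variance * np.exp(-np.abs(x1[:, None] - x2[None, :]) / lengthscale)


def dense_posterior_mean(x, y, variance, lengthscale, noise):
    K = matern12(x, x, variance, lengthscale)
    return K @ np.linalg.solve(K + noise * np.eye(len(x)), y)


def kalman_smoother_mean(x, y, variance, lengthscale, noise):
    """Posterior mean at sorted inputs x via Kalman filtering and RTS smoothing."""
    n = len(x)
    A = np.ones(n)                       # A[k]: transition from x[k-1] to x[k]
    m_pred, P_pred = np.zeros(n), np.zeros(n)
    m_filt, P_filt = np.zeros(n), np.zeros(n)
    m, P = 0.0, variance                 # stationary prior: mean 0, variance sigma^2
    for k in range(n):
        if k > 0:
            # Matern-1/2 (Ornstein-Uhlenbeck) transition over the gap x[k] - x[k-1]
            A[k] = np.exp(-(x[k] - x[k - 1]) / lengthscale)
            m = A[k] * m
            P = A[k] ** 2 * P + variance * (1.0 - A[k] ** 2)
        m_pred[k], P_pred[k] = m, P
        # Measurement update with y[k] = f(x[k]) + Gaussian noise
        gain = P / (P + noise)
        m = m + gain * (y[k] - m)
        P = (1.0 - gain) * P
        m_filt[k], P_filt[k] = m, P
    # RTS backward pass
    m_smooth = m_filt.copy()
    for k in range(n - 2, -1, -1):
        J = P_filt[k] * A[k + 1] / P_pred[k + 1]
        m_smooth[k] = m_filt[k] + J * (m_smooth[k + 1] - m_pred[k + 1])
    return m_smooth


rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0.0, 10.0, 50))
y = np.sin(x) + 0.1 * rng.standard_normal(50)
mean_dense = dense_posterior_mean(x, y, variance=1.0, lengthscale=2.0, noise=0.01)
mean_kf = kalman_smoother_mean(x, y, variance=1.0, lengthscale=2.0, noise=0.01)
print(np.max(np.abs(mean_dense - mean_kf)))  # agreement up to floating-point error
```

With non-Gaussian likelihoods, the approximate inference methods listed here roughly replace each likelihood term with a Gaussian site. The same filtering and smoothing recursions can then be applied to those sites.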

Expectation Propagation GPs

  • Expectation Propagation GP (Minka: A Family of Algorithms for Approximate Bayesian Inference, PhD thesis 2000)
  • Sparse Expectation Propagation GP (energy not working) (Csato, Opper: Sparse on-line Gaussian processes, Neural Computation 2002; Bui, Yan, Turner: A Unifying Framework for Gaussian Process Pseudo Point Approximations Using Power Expectation Propagation, JMLR 2017)
  • Markov Expectation Propagation GP (Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Expectation Propagation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Laplace/Newton GPs

  • Laplace GP (Rasmussen, Williams: Gaussian Processes for Machine Learning, 2006)
  • Sparse Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Markov Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Sparse Markov Laplace GP
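
The Laplace approximation finds the posterior mode with Newton's method, which is the starting point for the Bayes-Newton view of approximate inference. The following is a minimal NumPy sketch of the classic mode-finding loop (Rasmussen & Williams, Algorithm 3.1). It uses a plain linear solve instead of the more stable Cholesky-based form. The logistic likelihood with labels in {-1, +1} and the RBF kernel are assumptions chosen for illustration.

```python
import numpy as np

# Minimal sketch (assumptions: logistic likelihood, labels in {-1, +1}, RBF
# kernel): Newton's method for the mode of the Laplace approximation.


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def rbf(x1, x2, variance=1.0, lengthscale=1.0):
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dist / lengthscale ** 2)


def laplace_mode(K, y, max_iters=100, tol=1e-10):
    """Newton iterations on log p(y | f) + log N(f | 0, K)."""
    n = len(y)
    t = (y + 1.0) / 2.0                  # labels mapped to {0, 1}
    f = np.zeros(n)
    for _ in range(max_iters):
        pi = sigmoid(f)
        grad = t - pi                    # d log p(y | f) / df
        W = pi * (1.0 - pi)              # -d^2 log p(y | f) / df^2 (diagonal, > 0)
        b = W * f + grad
        # Newton step: f_new = (K^{-1} + W)^{-1} b = K (I + W K)^{-1} b
        f_new = K @ np.linalg.solve(np.eye(n) + W[:, None] * K, b)
        converged = np.max(np.abs(f_new - f)) < tol
        f = f_new
        if converged:
            break
    pi = sigmoid(f)
    return f, pi * (1.0 - pi)            # mode and Laplace site precisions W


x = np.linspace(-3.0, 3.0, 30)
y = np.where(np.sin(2.0 * x) > 0.0, 1.0, -1.0)
K = rbf(x, x, variance=2.0, lengthscale=0.8)
f_hat, W_hat = laplace_mode(K, y)
# At the mode the gradient of the objective vanishes: f_hat = K (t - sigmoid(f_hat))
print(np.max(np.abs(f_hat - K @ ((y + 1.0) / 2.0 - sigmoid(f_hat)))))
```

The diagonal `W_hat` at the mode gives the Gaussian site precisions of the Laplace approximation. For the logistic likelihood they are always positive because the likelihood is log-concave.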

Linearisation GPs

  • Posterior Linearisation GP (García-Fernández, Tronarp, Särkkä: Gaussian Process Classification Using Posterior Linearization, IEEE Signal Processing 2019; Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Sparse Posterior Linearisation GP
  • Markov Posterior Linearisation GP (García-Fernández, Svensson, Särkkä: Iterated Posterior Linearization Smoother, IEEE Automatic Control 2016; Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Posterior Linearisation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Taylor Expansion / Analytical Linearisation GP (Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Markov Taylor GP / Extended Kalman Smoother (Bell: The Iterated Kalman Smoother as a Gauss-Newton method, SIAM Journal on Optimization 1994)
  • Sparse Taylor GP
  • Sparse Markov Taylor GP / Sparse Extended Kalman Smoother (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Gauss-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Gauss-Newton
  • Variational Gauss-Newton
  • PEP Gauss-Newton
  • 2nd-order PL Gauss-Newton
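
One reason to prefer Gauss-Newton updates is that the exact Newton curvature of a negative log-likelihood can be negative, which would give an invalid (non-PSD) Gaussian site. The Gauss-Newton approximation keeps only a squared-Jacobian term, so it is nonnegative by construction. The following is a generic one-dimensional illustration of that effect. It assumes a hypothetical Gaussian likelihood with nonlinear mean g(f) = sin(f) and shows the standard Gauss-Newton idea, not necessarily the exact construction used in the paper.

```python
import numpy as np

# Generic illustration (hypothetical model, not taken from the paper):
#   -log p(y | f) = (y - g(f))**2 / (2 * sigma2) + const,  g(f) = sin(f)
# Newton uses the exact second derivative; Gauss-Newton drops the term that
# multiplies the residual, leaving the squared Jacobian, which is never negative.


def g(f):
    return np.sin(f)


def dg(f):
    return np.cos(f)


def d2g(f):
    return -np.sin(f)


def newton_curvature(f, y, sigma2):
    residual = y - g(f)
    return (dg(f) ** 2 - residual * d2g(f)) / sigma2


def gauss_newton_curvature(f, y, sigma2):
    return dg(f) ** 2 / sigma2


f = np.linspace(-3.0, 3.0, 13)
y, sigma2 = -1.0, 0.1
newton = newton_curvature(f, y, sigma2)
gauss_newton = gauss_newton_curvature(f, y, sigma2)
print(np.round(newton, 2))        # contains negative values
print(np.round(gauss_newton, 2))  # all values >= 0
```

Where the residual y - g(f) is zero, the two curvatures coincide. Gauss-Newton only departs from Newton where the model fits the data poorly.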

Quasi-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Quasi-Newton
  • Variational Quasi-Newton
  • PEP Quasi-Newton
  • PL Quasi-Newton

GPs with PSD Constraints via Riemannian Gradients

  • VI Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • Newton/Laplace Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • PEP Riemann Grad (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

Others

  • Infinite Horizon GP (Solin, Hensman, Turner: Infinite-Horizon Gaussian Processes, NeurIPS 2018)
  • Parallel Markov GP (with VI, EP, PL, ...) (Särkkä, García-Fernández: Temporal parallelization of Bayesian smoothers; Corenflos, Zhao, Särkkä: Gaussian Process Regression in Logarithmic Time; Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)
  • 2nd-order Posterior Linearisation GP (sparse, Markov, ...) (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
Comments
  • Addition of a Squared Exponential Kernel?

    This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?

    If applicable, I would be happy to submit a PR adding one.

    Just let me know.

    opened by mathDR 7
  • Add sq exponential kernel

    This PR does two things:

    1. Adds a squared exponential kernel, adopting the GPflow naming convention, i.e. SquaredExponential inheriting from StationaryKernel. For now, only the K_r method is populated (others may be populated as needed).
    2. The marathon.py demo was changed to take the SquaredExponential kernel (with lengthscale = 40). This serves as a "test" to ensure that the kernel runs.
    opened by mathDR 2
  • question about the equation (64) in `Bayes-Newton` paper

    I'm a little confused by equation (64). It is calculated from equation (63), but where are the other terms in equation (64), such as a denominator coming from the covariance of p(fn|u)?

    opened by Fangwq 2
  • jitted predict

    Hi, I'm starting to explore your framework. I'm familiar with JAX, but not with objax. I noticed that train ops are jitted with objax.Jit, but my goal is to have fast prediction embedded in some larger JAX code. Can predict() also be jitted? Thanks in advance,

    Regards,

    opened by soldierofhell 2
  • error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

    As the title says, there is an error after running heteroscedastic.py with the MarkovPosteriorLinearisationQuasiNewtonGP method:

      File "heteroscedastic.py", line 101, in train_op
        model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
      File "/BayesNewton-main/bayesnewton/inference.py", line 871, in inference
        mean, jacobian, hessian, quasi_newton_state = self.update_variational_params(batch_ind, lr, **kwargs)
      File "/BayesNewton-main/bayesnewton/inference.py", line 1076, in update_variational_params
        jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
    ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)

    Can you tell me what is wrong? Thanks in advance.

    opened by Fangwq 2
  • How to set initial `Pinf` variable in kernel?

    I note that the initial Pinf variable for the Matern-5/2 kernel is as follows:

            Pinf = np.array([[self.variance,    0.0,   -kappa],
                             [0.0,    kappa, 0.0],
                             [-kappa, 0.0,   25.0*self.variance / self.lengthscale**4.0]])
    

    Why is it like that? Are there any references I should follow up?

    PS: by the way, the data file ../data/aq_data.csv is missing.

    opened by Fangwq 2
  • How to install Newt in conda virtual environment?

    Hi, thank you for sharing your great work.

    I am a little confused about how to install Newt in a conda virtual environment. I would really appreciate it if you could guide me in this regard. Thank you

    opened by mohammad-saber 2
  • How to understand the function `cavity_distribution_tied` in file `basemodels.py` ?

    As the title says, how should I understand cavity_distribution_tied in the file basemodels.py? Is there any reference I should follow up? I note that this code is similar to equation (64) in the Bayes-Newton paper. Where does it come from?

    opened by Fangwq 1
  • issue with SparseVariationalGP method

    When I run the demo file demos/regression.py with SparseVariationalGP, the following error occurs:

    AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.

    It seems there is a mistake in the SparseVariationalGP method. Can you help fix the problem? Thank you very much!

    opened by Fangwq 1
  • Sparse EP energy is incorrect

    The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.

    Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)

    bug 
    opened by wil-j-wil 1
  • Double Precision Issues

    Hi!

    Many thanks for open-sourcing this package.

    I've been using code from this amazing package for my research (preprint here) and have found that the default single precision/float32 is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular,

    • the Periodic kernel is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts();
    • the same happens when the lengthscales in the Matern32 kernel are too large.

    However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, because doing so puts all arrays into double precision.

    Currently, my solution is to set the neural network weights to float32 manually, and convert input arrays into float32 before entering the network and the outputs back into float64. However, I was wondering if there could be a more elegant solution as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.

    Software and hardware details:

    objax==1.6.0
    jax==0.3.13
    jaxlib==0.3.10+cuda11.cudnn805
    
    NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
    GeForce RTX 3090 GPUs
    

    Thanks in advance.

    Best, Harrison

    opened by harrisonzhu508 1
  • Cannot run demo, possible incompatibility with latest Jax

    Dear all,

    I am trying to run the demo examples, but I run into the following error:

    ImportError                               Traceback (most recent call last)
    Input In [22], in <cell line: 1>()
    ----> 1 import bayesnewton
          2 import objax
          3 import numpy as np

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1
    ----> 1 from . import (
          2     kernels,
          3     utils,
          4     ops,
          5     likelihoods,
          6     models,
          7     basemodels,
          8     inference,
          9     cubature
         10 )
         13 def build_model(model, inf, name='GPModel'):
         14     return type(name, (inf, model), {})

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5
          3 import jax.numpy as np
          4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
    ----> 5 from jax.ops import index_add, index
          6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
          7 from warnings import warn

    ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)

    I think it's related to this note from the JAX website:

    The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.

    opened by daniel-trejobanos 3
  • Latest versions of JAX and objax cause compile slow down

    It is recommended to use the following versions of jax and objax:

    jax==0.2.9
    jaxlib==0.1.60
    objax==1.3.1
    

    This is because of this objax issue, which causes the model to JIT-compile "twice", i.e. on the first two iterations rather than just the first. This causes a bit of a slowdown for large models, but is not a problem otherwise.

    bug 
    opened by wil-j-wil 0
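
Regarding the `Pinf` question in the comments above: in state-space form, the Matérn-5/2 GP has state (f, f', f''), and `Pinf` is the stationary covariance of that state. Its entries are therefore:

- k(0) = variance
- Var(f') = -k''(0) = kappa
- Cov(f, f'') = k''(0) = -kappa
- Var(f'') = k''''(0) = 25 * variance / lengthscale**4

Here kappa = 5 * variance / (3 * lengthscale**2). Equivalently, `Pinf` solves the continuous Lyapunov equation F Pinf + Pinf F^T + L Qc L^T = 0. The NumPy sketch below checks this numerically. The matrices F, L and Qc are the standard Matérn-5/2 construction, assumed here rather than copied from kernels.py.

```python
import numpy as np

# Sketch under stated assumptions: standard state-space form of the Matern-5/2
# kernel with state (f, f', f'') and lam = sqrt(5) / lengthscale.


def matern52_state_space(variance, lengthscale):
    lam = np.sqrt(5.0) / lengthscale
    F = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [-lam ** 3, -3.0 * lam ** 2, -3.0 * lam]])
    L = np.array([[0.0], [0.0], [1.0]])
    Qc = 16.0 / 3.0 * variance * lam ** 5   # spectral density of the driving white noise
    return F, L, Qc


def solve_lyapunov(F, Q):
    """Solve F P + P F^T + Q = 0 for P via column-major vectorisation."""
    n = F.shape[0]
    eye = np.eye(n)
    # vec(F P) = (I kron F) vec(P),  vec(P F^T) = (F kron I) vec(P)
    A = np.kron(eye, F) + np.kron(F, eye)
    p = np.linalg.solve(A, -Q.flatten(order="F"))
    return p.reshape((n, n), order="F")


def pinf_closed_form(variance, lengthscale):
    # Same pattern as the snippet in the issue; kappa is assumed to be
    # 5/3 * variance / lengthscale**2, which the Lyapunov solution requires.
    kappa = 5.0 / 3.0 * variance / lengthscale ** 2
    return np.array([[variance, 0.0, -kappa],
                     [0.0, kappa, 0.0],
                     [-kappa, 0.0, 25.0 * variance / lengthscale ** 4]])


variance, lengthscale = 2.0, 1.5
F, L, Qc = matern52_state_space(variance, lengthscale)
Pinf = solve_lyapunov(F, Qc * (L @ L.T))
print(np.allclose(Pinf, pinf_closed_form(variance, lengthscale)))  # True
```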
Releases (v1.2.0)
Owner
AaltoML
Machine learning group at Aalto University led by Prof. Solin