tree-math: mathematical operations for JAX pytrees

Overview

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in a way that handles arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies and gives the user more control over the memory layouts used in computation. In practice, this can often make a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an object that supports arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means they are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS, which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final