tree-math: mathematical operations for JAX pytrees

Overview

tree-math: mathematical operations for JAX pytrees

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an a object that arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means the are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see the "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near-term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

atol2) & (k < maxiter) def body_fun(value): x, r, gamma, p, k = value Ap = A(p) alpha = gamma / (p.conj() @ Ap) x_ = x + alpha * p r_ = r - alpha * Ap z_ = M(r_) gamma_ = r_.conj() @ z_ beta_ = gamma_ / gamma p_ = z_ + beta_ * p return x_, r_, gamma_, p_, k + 1 r0 = b - A(x0) p0 = z0 = M(r0) gamma0 = r0 @ z0 initial_value = (x0, r0, gamma0, p0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) return x_final">
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google
Google ❤️ Open Source
Google
Employee-Managment - Company employee registration software in the face recognition system

Employee-Managment Company employee registration software in the face recognitio

Alireza Kiaeipour 7 Jul 10, 2022
Unsupervised clustering of high content screen samples

Microscopium Unsupervised clustering and dataset exploration for high content screens. See microscopium in action Public dataset BBBC021 from the Broa

60 Dec 05, 2022
Experiments and examples converting Transformers to ONNX

Experiments and examples converting Transformers to ONNX This repository containes experiments and examples on converting different Transformers to ON

Philipp Schmid 4 Dec 24, 2022
HiPAL: A Deep Framework for Physician Burnout Prediction Using Activity Logs in Electronic Health Records

HiPAL Code for KDD'22 Applied Data Science Track submission -- HiPAL: A Deep Framework for Physician Burnout Prediction Using Activity Logs in Electro

Hanyang Liu 4 Aug 08, 2022
Generative code template for PixelBeasts 10k NFT project.

generator-template Generative code template for combining transparent png attributes into 10,000 unique images. Used for the PixelBeasts 10k NFT proje

Yohei Nakajima 9 Aug 24, 2022
[NeurIPS2021] Code Release of Learning Transferable Perturbations

Learning Transferable Adversarial Perturbations This is an official release of the paper Learning Transferable Adversarial Perturbations. The code is

Krishna Kanth 17 Nov 11, 2022
fcn by tensorflow

Update An example on how to integrate this code into your own semantic segmentation pipeline can be found in my KittiSeg project repository. tensorflo

9 May 22, 2022
The implementation of FOLD-R++ algorithm

FOLD-R-PP The implementation of FOLD-R++ algorithm. The target of FOLD-R++ algorithm is to learn an answer set program for a classification task. Inst

13 Dec 23, 2022
The code for our CVPR paper PISE: Person Image Synthesis and Editing with Decoupled GAN, Project Page, supp.

PISE The code for our CVPR paper PISE: Person Image Synthesis and Editing with Decoupled GAN, Project Page, supp. Requirement conda create -n pise pyt

jinszhang 110 Nov 21, 2022
Predicting Price of house by considering ,house age, Distance from public transport

House-Price-Prediction Predicting Price of house by considering ,house age, Distance from public transport, No of convenient stores around house etc..

Musab Jaleel 1 Jan 08, 2022
A Partition Filter Network for Joint Entity and Relation Extraction EMNLP 2021

EMNLP 2021 - A Partition Filter Network for Joint Entity and Relation Extraction

zhy 127 Jan 04, 2023
Real-time Object Detection for Streaming Perception, CVPR 2022

StreamYOLO Real-time Object Detection for Streaming Perception Jinrong Yang, Songtao Liu, Zeming Li, Xiaoping Li, Sun Jian Real-time Object Detection

Jinrong Yang 237 Dec 27, 2022
FS2KToolbox FS2K Dataset Towards the translation between Face

FS2KToolbox FS2K Dataset Towards the translation between Face -- Sketch. Download (photo+sketch+annotation): Google-drive, Baidu-disk, pw: FS2K. For

Deng-Ping Fan 5 Jan 03, 2023
Exploring Image Deblurring via Blur Kernel Space (CVPR'21)

Exploring Image Deblurring via Encoded Blur Kernel Space About the project We introduce a method to encode the blur operators of an arbitrary dataset

VinAI Research 118 Dec 19, 2022
Hippocampal segmentation using the UNet network for each axis

Hipposeg Hippocampal segmentation using the UNet network for each axis, inspired by https://github.com/MICLab-Unicamp/e2dhipseg Red: False Positive Gr

Juan Carlos Aguirre Arango 0 Sep 02, 2021
This repository is for Competition for ML_data class

This repository is for Competition for ML_data class. Based on mmsegmentatoin,mainly using swin transformer to completed the competition.

jianlong 2 Oct 23, 2022
Implicit Graph Neural Networks

Implicit Graph Neural Networks This repository is the official PyTorch implementation of "Implicit Graph Neural Networks". Fangda Gu*, Heng Chang*, We

Heng Chang 48 Nov 29, 2022
OpenMMLab Pose Estimation Toolbox and Benchmark.

Introduction English | 简体中文 MMPose is an open-source toolbox for pose estimation based on PyTorch. It is a part of the OpenMMLab project. The master b

OpenMMLab 2.8k Dec 31, 2022
Locationinfo - A script helps the user to show network information such as ip address

Description This script helps the user to show network information such as ip ad

Roxcoder 1 Dec 30, 2021
Hashformers is a framework for hashtag segmentation with transformers.

Hashtag segmentation is the task of automatically inserting the missing spaces between the words in a hashtag. Hashformers applies Transformer models

Ruan Chaves 41 Nov 09, 2022