tree-math: mathematical operations for JAX pytrees

Overview

tree-math: mathematical operations for JAX pytrees

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an a object that arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means the are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see the "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near-term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

atol2) & (k < maxiter) def body_fun(value): x, r, gamma, p, k = value Ap = A(p) alpha = gamma / (p.conj() @ Ap) x_ = x + alpha * p r_ = r - alpha * Ap z_ = M(r_) gamma_ = r_.conj() @ z_ beta_ = gamma_ / gamma p_ = z_ + beta_ * p return x_, r_, gamma_, p_, k + 1 r0 = b - A(x0) p0 = z0 = M(r0) gamma0 = r0 @ z0 initial_value = (x0, r0, gamma0, p0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) return x_final">
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google
Google ❤️ Open Source
Google
“Data Augmentation for Cross-Domain Named Entity Recognition” (EMNLP 2021)

Data Augmentation for Cross-Domain Named Entity Recognition Authors: Shuguang Chen, Gustavo Aguilar, Leonardo Neves and Thamar Solorio This repository

<a href=[email protected]"> 18 Sep 10, 2022
Bayesian-Torch is a library of neural network layers and utilities extending the core of PyTorch to enable the user to perform stochastic variational inference in Bayesian deep neural networks

Bayesian-Torch is a library of neural network layers and utilities extending the core of PyTorch to enable the user to perform stochastic variational inference in Bayesian deep neural networks. Bayes

Intel Labs 210 Jan 04, 2023
Predicting 10 different clothing types using Xception pre-trained model.

Predicting-Clothing-Types Predicting 10 different clothing types using Xception pre-trained model from Keras library. It is reimplemented version from

AbdAssalam Ahmad 3 Dec 29, 2021
MediaPipeのPythonパッケージのサンプルです。2020/12/11時点でPython実装のある4機能(Hands、Pose、Face Mesh、Holistic)について用意しています。

mediapipe-python-sample MediaPipeのPythonパッケージのサンプルです。 2020/12/11時点でPython実装のある以下4機能について用意しています。 Hands Pose Face Mesh Holistic Requirement mediapipe 0.

KazuhitoTakahashi 217 Dec 12, 2022
DeepLab is a state-of-art deep learning system for semantic image segmentation built on top of Caffe.

DeepLab Introduction DeepLab is a state-of-art deep learning system for semantic image segmentation built on top of Caffe. It combines densely-compute

Ali 234 Nov 14, 2022
This repo is to be freely used by ML devs to check the GAN performances without coding from scratch.

GANs for Fun Created because I can! GOAL The goal of this repo is to be freely used by ML devs to check the GAN performances without coding from scrat

Sagnik Roy 13 Jan 26, 2022
A Jupyter notebook to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

A Jupyter notebook to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

Eugenio Herrera 175 Dec 29, 2022
StyleGAN2 - Official TensorFlow Implementation

StyleGAN2 - Official TensorFlow Implementation

NVIDIA Research Projects 10.1k Dec 28, 2022
QueryInst: Parallelly Supervised Mask Query for Instance Segmentation

QueryInst is a simple and effective query based instance segmentation method driven by parallel supervision on dynamic mask heads, which outperforms previous arts in terms of both accuracy and speed.

Hust Visual Learning Team 386 Jan 08, 2023
Megaverse is a new 3D simulation platform for reinforcement learning and embodied AI research

Megaverse Megaverse is a new 3D simulation platform for reinforcement learning and embodied AI research. The efficient design of the engine enables ph

Aleksei Petrenko 191 Dec 23, 2022
Cycle Consistent Adversarial Domain Adaptation (CyCADA)

Cycle Consistent Adversarial Domain Adaptation (CyCADA) A pytorch implementation of CyCADA. If you use this code in your research please consider citi

Hyunwoo Ko 2 Jan 10, 2022
An implementation of a discriminant function over a normal distribution to help classify datasets.

CS4044D Machine Learning Assignment 1 By Dev Sony, B180297CS The question, report and source code can be found here. Github Repo Solution 1 Based on t

Dev Sony 6 Nov 09, 2021
Official implementation of "StyleCariGAN: Caricature Generation via StyleGAN Feature Map Modulation" (SIGGRAPH 2021)

StyleCariGAN in PyTorch Official implementation of StyleCariGAN:Caricature Generation via StyleGAN Feature Map Modulation in PyTorch Requirements PyTo

PeterZhouSZ 49 Oct 31, 2022
"Exploring Vision Transformers for Fine-grained Classification" at CVPRW FGVC8

FGVC8 Exploring Vision Transformers for Fine-grained Classification paper presented at the CVPR 2021, The Eight Workshop on Fine-Grained Visual Catego

Marcos V. Conde 19 Dec 06, 2022
Deep Learning Interviews book: Hundreds of fully solved job interview questions from a wide range of key topics in AI.

This book was written for you: an aspiring data scientist with a quantitative background, facing down the gauntlet of the interview process in an increasingly competitive field. For most of you, the

4.1k Dec 28, 2022
Learning to Reach Goals via Iterated Supervised Learning

Vanilla GCSL This repository contains a vanilla implementation of "Learning to Reach Goals via Iterated Supervised Learning" proposed by Dibya Gosh et

Christoph Heindl 4 Aug 10, 2022
Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression

Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression YOLOv5 with alpha-IoU losses implemented in PyTorch. Example r

Jacobi(Jiabo He) 147 Dec 05, 2022
The VeriNet toolkit for verification of neural networks

VeriNet The VeriNet toolkit is a state-of-the-art sound and complete symbolic interval propagation based toolkit for verification of neural networks.

9 Dec 21, 2022
Official implementation for (Show, Attend and Distill: Knowledge Distillation via Attention-based Feature Matching, AAAI-2021)

Show, Attend and Distill: Knowledge Distillation via Attention-based Feature Matching Official pytorch implementation of "Show, Attend and Distill: Kn

Clova AI Research 80 Dec 16, 2022
The story of Chicken for Club Bing

Chicken Story tl;dr: The time when Microsoft banned my entire country for cheating at Club Bing. (A lot of the details are from memory so I've recreat

Eyal 142 May 16, 2022