tree-math: mathematical operations for JAX pytrees

Overview

tree-math: mathematical operations for JAX pytrees

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an a object that arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means the are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see the "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near-term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

atol2) & (k < maxiter) def body_fun(value): x, r, gamma, p, k = value Ap = A(p) alpha = gamma / (p.conj() @ Ap) x_ = x + alpha * p r_ = r - alpha * Ap z_ = M(r_) gamma_ = r_.conj() @ z_ beta_ = gamma_ / gamma p_ = z_ + beta_ * p return x_, r_, gamma_, p_, k + 1 r0 = b - A(x0) p0 = z0 = M(r0) gamma0 = r0 @ z0 initial_value = (x0, r0, gamma0, p0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) return x_final">
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google
Google ❤️ Open Source
Google
FS-Mol: A Few-Shot Learning Dataset of Molecules

FS-Mol is A Few-Shot Learning Dataset of Molecules, containing molecular compounds with measurements of activity against a variety of protein targets. The dataset is presented with a model evaluation

Microsoft 114 Dec 15, 2022
Official Code for AdvRush: Searching for Adversarially Robust Neural Architectures (ICCV '21)

AdvRush Official Code for AdvRush: Searching for Adversarially Robust Neural Architectures (ICCV '21) Environmental Set-up Python == 3.6.12, PyTorch =

11 Dec 10, 2022
A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

CapsNet-Tensorflow A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules Notes: The current version

Huadong Liao 3.8k Dec 29, 2022
Codes for building and training the neural network model described in Domain-informed neural networks for interaction localization within astroparticle experiments.

Domain-informed Neural Networks Codes for building and training the neural network model described in Domain-informed neural networks for interaction

DIDACTS 0 Dec 13, 2021
A very simple tool to rewrite parameters such as attributes and constants for OPs in ONNX models. Simple Attribute and Constant Modifier for ONNX.

sam4onnx A very simple tool to rewrite parameters such as attributes and constants for OPs in ONNX models. Simple Attribute and Constant Modifier for

Katsuya Hyodo 6 May 15, 2022
Project page for End-to-end Recovery of Human Shape and Pose

End-to-end Recovery of Human Shape and Pose Angjoo Kanazawa, Michael J. Black, David W. Jacobs, Jitendra Malik CVPR 2018 Project Page Requirements Pyt

1.4k Dec 29, 2022
PyTorch code for our ECCV 2020 paper "Single Image Super-Resolution via a Holistic Attention Network"

HAN PyTorch code for our ECCV 2020 paper "Single Image Super-Resolution via a Holistic Attention Network" This repository is for HAN introduced in the

五维空间 140 Nov 23, 2022
NL-Augmenter 🦎 → 🐍 A Collaborative Repository of Natural Language Transformations

NL-Augmenter 🦎 → 🐍 The NL-Augmenter is a collaborative effort intended to add transformations of datasets dealing with natural language. Transformat

684 Jan 09, 2023
CAUSE: Causality from AttribUtions on Sequence of Events

CAUSE: Causality from AttribUtions on Sequence of Events

Wei Zhang 21 Dec 01, 2022
Yolov5 deepsort inference,使用YOLOv5+Deepsort实现车辆行人追踪和计数,代码封装成一个Detector类,更容易嵌入到自己的项目中

使用YOLOv5+Deepsort实现车辆行人追踪和计数,代码封装成一个Detector类,更容易嵌入到自己的项目中。

813 Dec 31, 2022
This repository provides a PyTorch implementation and model weights for HCSC (Hierarchical Contrastive Selective Coding)

HCSC: Hierarchical Contrastive Selective Coding This repository provides a PyTorch implementation and model weights for HCSC (Hierarchical Contrastive

YUANFAN GUO 111 Dec 20, 2022
DeLag: Detecting Latency Degradation Patterns in Service-based Systems

DeLag: Detecting Latency Degradation Patterns in Service-based Systems Replication package of the work "DeLag: Detecting Latency Degradation Patterns

SEALABQualityGroup @ University of L'Aquila 2 Mar 24, 2022
A denoising diffusion probabilistic model (DDPM) tailored for conditional generation of protein distograms

Denoising Diffusion Probabilistic Model for Proteins Implementation of Denoising Diffusion Probabilistic Model in Pytorch. It is a new approach to gen

Phil Wang 108 Nov 23, 2022
Interpretation of T cell states using reference single-cell atlases

Interpretation of T cell states using reference single-cell atlases ProjecTILs is a computational method to project scRNA-seq data into reference sing

Cancer Systems Immunology Lab 139 Jan 03, 2023
Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite.

TFlite Ultra Fast Lane Detection Inference Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite. So

Ibai Gorordo 12 Aug 27, 2022
李云龙二次元风格化!打滚卖萌,使用了animeGANv2进行了视频的风格迁移

李云龙二次元风格化!一键star、fork,你也可以生成这样的团长! 打滚卖萌求star求fork! 0.效果展示 视频效果前往B站观看效果最佳:李云龙二次元风格化: github开源repo:李云龙二次元风格化 百度AIstudio开源地址,一键fork即可运行: 李云龙二次元风格化!一键fork

oukohou 44 Dec 04, 2022
Dist2Dec: A Simplicial Neural Network for Homology Localization

Dist2Dec: A Simplicial Neural Network for Homology Localization

Alexandros Keros 6 Jun 12, 2022
The repo contains the code of the ACL2020 paper `Dice Loss for Data-imbalanced NLP Tasks`

Dice Loss for NLP Tasks This repository contains code for Dice Loss for Data-imbalanced NLP Tasks at ACL2020. Setup Install Package Dependencies The c

223 Dec 17, 2022
Data augmentation for NLP, accepted at EMNLP 2021 Findings

AEDA: An Easier Data Augmentation Technique for Text Classification This is the code for the EMNLP 2021 paper AEDA: An Easier Data Augmentation Techni

Akbar Karimi 81 Dec 09, 2022
Official implement of Paper:A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sening images

A deeply supervised image fusion network for change detection in high resolution bi-temporal remote sensing images 深度监督影像融合网络DSIFN用于高分辨率双时相遥感影像变化检测 Of

Chenxiao Zhang 135 Dec 19, 2022