A small library for creating and manipulating custom JAX Pytree classes

Overview

Treeo

  • Light-weight: has no dependencies other than jax.
  • Compatible: Treeo Tree objects are compatible with any jax function that accepts Pytrees.
  • Standards-based: treeo.field is built on top of python's dataclasses.field.
  • Flexible: Treeo is compatible with both dataclass and non-dataclass classes.

Treeo lets you create class-based Pytrees so your custom objects can interact seamlessly with JAX. Uses of Treeo range from creating simple JAX-aware utility classes to serving as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of Treex and shares a lot in common with flax.struct.

Documentation | User Guide

Installation

Install using pip:

pip install treeo

Basics

With Treeo you can easily define your own custom Pytree classes by inheriting from Treeo's Tree class and using the field function to declare which fields are nodes (children) and which are static (metadata):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True) # I am a node field!
    name: str = to.field(node=False) # I am a static field!

field is just a wrapper around dataclasses.field, so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well. Since all Tree instances are Pytrees, they work with the various functions from the jax library as expected:

p = Person(height=jnp.array(1.8), name="John")

# Trees can be jitted!
jax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')

# Trees can be mapped!
jax.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')
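
To make the node/static split more concrete, here is a minimal standard-library sketch of the idea. It is not Treeo's implementation: sketch_field, flatten, and unflatten are hypothetical names, and the only assumption is that a wrapper around dataclasses.field can record a node flag in the field's metadata.

```python
from dataclasses import dataclass, field, fields


def sketch_field(*, node, **kwargs):
    """Hypothetical stand-in for a field wrapper: stores the node flag in metadata."""
    return field(metadata={"node": node}, **kwargs)


@dataclass
class Person:
    height: float = sketch_field(node=True)   # node (child)
    name: str = sketch_field(node=False)      # static (metadata)


def flatten(obj):
    """Split an instance into (children, static) using each field's node flag."""
    children = {f.name: getattr(obj, f.name) for f in fields(obj) if f.metadata["node"]}
    static = {f.name: getattr(obj, f.name) for f in fields(obj) if not f.metadata["node"]}
    return children, static


def unflatten(cls, children, static):
    """Rebuild an instance from its children and static parts."""
    return cls(**children, **static)


p = Person(height=1.8, name="John")
children, static = flatten(p)

# Mapping touches only the children; static data is carried along unchanged.
doubled = unflatten(Person, {k: 2 * v for k, v in children.items()}, static)
print(children, static, doubled)
```

A mapping function only ever sees the children, which is why name survives the mapping untouched in the Person example above.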

Kinds

Treeo also includes a kind system that lets you give semantic meaning to fields (what a field represents within your application). A kind is just a type you pass to field via its kind argument:

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

Kinds are very useful as a filtering mechanism via treeo.filter:

model = BatchNorm(...)

# select only Parameters, mean is filtered out
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)

Nothing behaves like None in Python, but it is a special value that is used to represent the absence of a value within Treeo.
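
The filtering behaviour can be pictured with a plain-dataclass sketch. This is an illustration only, not how Treeo is implemented: NOTHING and sketch_filter are hypothetical names, kinds are stored directly in dataclass metadata, and matching by issubclass is an assumption.

```python
from dataclasses import dataclass, field, fields, replace


class Parameter: pass
class BatchStat: pass


class _Nothing:
    """Hypothetical sentinel marking a filtered-out field."""
    def __repr__(self):
        return "Nothing"


NOTHING = _Nothing()


@dataclass
class BatchNorm:
    scale: float = field(metadata={"kind": Parameter})
    mean: float = field(metadata={"kind": BatchStat})


def sketch_filter(tree, kind):
    """Return a copy in which fields whose kind does not match become NOTHING."""
    updates = {
        f.name: NOTHING
        for f in fields(tree)
        if not issubclass(f.metadata["kind"], kind)
    }
    return replace(tree, **updates)


model = BatchNorm(scale=1.0, mean=0.5)
params = sketch_filter(model, Parameter)
print(params)  # BatchNorm(scale=1.0, mean=Nothing)
```

The original model is left untouched; the filtered copy keeps the same structure with placeholders where values were dropped.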

Treeo also offers the merge function, which lets you rejoin filtered Trees using logic similar to Python's dict.update, applied recursively:

def loss_fn(params, model, ...):
    # add traced params to model
    model = to.merge(model, params)
    ...

# gradient only w.r.t. params
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)
grads = jax.grad(loss_fn)(params, model, ...)
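
The recursive dict.update analogy can be sketched with nested dictionaries standing in for Trees. This is a conceptual illustration under stated assumptions, not Treeo's code: sketch_merge and NOTHING are hypothetical names, and the rule that placeholder values never overwrite existing ones is an assumption made for this sketch.

```python
class _Nothing:
    """Hypothetical sentinel for a missing value."""
    def __repr__(self):
        return "Nothing"


NOTHING = _Nothing()


def sketch_merge(base, *others):
    """Recursively update base with others; NOTHING entries never overwrite."""
    result = dict(base)
    for other in others:
        for key, value in other.items():
            if value is NOTHING:
                continue  # keep whatever base already holds
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                result[key] = sketch_merge(result[key], value)  # recurse into subtrees
            else:
                result[key] = value
    return result


model = {"scale": 1.0, "mean": 0.5, "sub": {"w": 2.0, "stat": 3.0}}
params = {"scale": 10.0, "mean": NOTHING, "sub": {"w": 20.0, "stat": NOTHING}}

merged = sketch_merge(model, params)
print(merged)
```

In the loss_fn pattern above, this is what allows the traced parameters to replace their counterparts in the model while the remaining fields keep their current values.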

For a more in-depth tour check out the User Guide.

Examples

A simple Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Character(to.Tree):
    position: jnp.ndarray = to.field(node=True)    # node field
    name: str = to.field(node=False, opaque=True)  # static field

character = Character(position=jnp.array([0, 0]), name='Adam')

# character can freely pass through jit
@jax.jit
def update(character: Character, velocity, dt) -> Character:
    character.position += velocity * dt
    return character

character = update(character, velocity=jnp.array([1.0, 0.2]), dt=0.1)

A Stateful Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Counter(to.Tree):
    n: jnp.ndarray = to.field(default=jnp.array(0), node=True) # node
    step: int = to.field(default=1, node=False) # static

    def inc(self):
        self.n += self.step

counter = Counter(step=2) # Counter(n=jnp.array(0), step=2)

@jax.jit
def update(counter: Counter):
    counter.inc()
    return counter

counter = update(counter) # Counter(n=jnp.array(2), step=2)

# map over the tree

Full Example - Linear Regression

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import treeo as to


class Linear(to.Tree):
    w: jnp.ndarray = to.node()
    b: jnp.ndarray = to.node()

    def __init__(self, din, dout, key):
        self.w = jax.random.uniform(key, shape=(din, dout))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b


@jax.value_and_grad
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y) ** 2)

    return loss


def sgd(param, grad):
    return param - 0.1 * grad


@jax.jit
def train_step(model, x, y):
    loss, grads = loss_fn(model, x, y)
    model = jax.tree_map(sgd, model, grads)

    return loss, model


x = np.random.uniform(size=(500, 1))
y = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))

key = jax.random.PRNGKey(0)
model = Linear(1, 1, key=key)

for step in range(1000):
    loss, model = train_step(model, x, y)
    if step % 100 == 0:
        print(f"loss: {loss:.4f}")

X_test = np.linspace(x.min(), x.max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(x, y, c="k", label="data")
plt.plot(X_test, y_pred, c="b", linewidth=2, label="prediction")
plt.legend()
plt.show()
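
The update in train_step relies on mapping a function over two trees with the same structure (the model and its gradients). A rough standard-library sketch of that idea follows; sketch_tree_map and LinearParams are hypothetical names, plain floats stand in for arrays, and the gradient values are made up for illustration.

```python
from dataclasses import dataclass, fields


@dataclass
class LinearParams:
    w: float
    b: float


def sketch_tree_map(fn, tree, *rest):
    """Apply fn leaf-wise across trees that share the same dataclass structure."""
    return type(tree)(**{
        f.name: fn(getattr(tree, f.name), *(getattr(r, f.name) for r in rest))
        for f in fields(tree)
    })


def sgd(param, grad):
    return param - 0.1 * grad


model = LinearParams(w=1.0, b=0.0)
grads = LinearParams(w=2.0, b=-1.0)  # illustrative values only

updated = sketch_tree_map(sgd, model, grads)
print(updated)
```

Because the gradients have the same structure as the model, each parameter is paired with its own gradient and the result is a new object of the same class.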

Comments

  • Use field kinds within tree_map

    Firstly, thanks for creating Treeo - it's a fantastic package.

    Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class Parameter:
        def transform(self):
            return jnp.exp(self)
    
    
    @dataclass
    class Model(to.Tree):
        lengthscale: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=Parameter
        )
    

    is there a way that I could do something similar to the following pseudocode snippet:

    m = Model()
    jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
    
    opened by thomaspinder 10
  • Stacking of Treeo.Tree

    I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    @dataclass
    class Person(to.Tree):
        height: jnp.array = to.field(node=True) # I am a node field!
        age_static: jnp.array = to.field(node=False) # I am a static field!, I should not be updated.
        name: str = to.field(node=False) # I am a static field!
    
    persons = [
        Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
        Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
        Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
    ]
    
    # Stack (struct of arrays instead of list of structs)
    jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    

    However, this fails with the following exception:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In[1], line 18
         11     name: str = to.field(node=False) # I am a static field!
         13 persons = [
         14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
         15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
         16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
         17 ]
    ---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
    

    Versions used:

    • JAX: 0.3.20
    • Treeo: 0.0.10

    From a certain perspective this is expected, because jax.tree_map does not apply to static (node=False) fields. In that sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to stack objects like this when they have static fields. Has anyone tried something similar and come up with a nice solution?

    opened by peterroelants 3
  • Jitting twice for a class method

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class A(to.Tree):
        X: jnp.array = to.field(node=True)
        
        def __init__(self):
            self.X = jnp.ones((50, 50))
    
        @jax.jit
        def f(self, Y):
            return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)
    
    Y = jnp.ones(2)
    for i in range(5):
        print(A.f._cache_size())
        a = A()
        a.f(Y)
    

    The output of the above is 0 1 2 2 2 with jax 0.3.15. No idea what's happening. It seems to work fine with 0.3.10 and the output is 0 1 1 1 1. Thanks.

    opened by pipme 2
  • Change Mutable API

    Changes

    • Previously: self.mutable(*args, method=method, **kwargs)
    • Now: self.mutable(method=method)(*args, **kwargs)
    • Opaque API is removed
    • inplace argument is now only available for apply.
    • Immutable.{mutable, toplevel_mutable} methods are removed.
    fix 
    opened by cgarciae 1
  • Improve mutability support

    Changes

    • Fixes issues with immutability in compact context
    • The make_mutable context manager and the mutable function now expose a toplevel_only: bool argument.
    • Adds a _get_unbound_method private function in utils.
    feature 
    opened by cgarciae 1
  • Bug Fixes from 0.0.11

    Changes

    • Fixes an issue that disabled mutability inside __init__ for Immutable classes when TreeMeta's constructor method is overloaded.
    • Fixes the Apply.apply mixin method.

    Closes cgarciae/treex#68

    fix 
    opened by cgarciae 1
  • Adds support for immutable Trees

    Changes

    • Adds an Immutable mixin that can make Trees effectively immutable (as far as python permits).
    • Immutable contains the .replace and .mutable methods that let you manipulate state in a functionally pure fashion.
    • Adds the mutable function transformation / decorator, which lets you turn functions that perform mutable operations into pure functions.
    opened by cgarciae 1
  • Add the option of using add_field_info inside map

    This PR addresses the comments made in #2. An additional argument is added to map so that a field_info boolean flag can be passed. When true, jax.tree_map is carried out under the with add_field_info(): context manager.

    Tests have been added to check for correct function application on classes containing Trees with mixed kind types.

    A brief section has been added to the documentation to reflect the above changes.

    opened by thomaspinder 1
  • Get all unique kinds

    Hi,

    Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

    class KindOne: pass
    class KindTwo: pass
    
    @dataclass
    class SubModel(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindOne
        )
    
    
    @dataclass 
    class Model(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindTwo
        )
    
    m = Model()
    
    m.unique_kinds() # [KindOne, KindTwo]
    
    opened by thomaspinder 1
  • Compact

    Changes

    • Removes opaque_is_equal; the same functionality is available through opaque.
    • Adds the compact decorator, which enables the definition of Tree subnodes at runtime.
    • Adds the Compact mixin, which adds the first_run property and the get_field method.
    opened by cgarciae 0
  • Relax jax/jaxlib version constraints

    Now that jax 0.3.0 and jaxlib 0.3.0 have been released the version constraints in pyproject.toml are outdated.

    https://github.com/cgarciae/treeo/blob/a402f3f69557840cfbee4d7804964b8e2c47e3f7/pyproject.toml#L16-L17

    This corresponds to the version constraint jax<0.3.0,>=0.2.18 (https://python-poetry.org/docs/dependency-specification/#caret-requirements). Now that jax v0.3.0 has been released (https://github.com/google/jax/releases/tag/jax-v0.3.0) this doesn't work with the latest version. I think the same applies to jaxlib as well, since it also got upgraded to v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).

    opened by samuela 4
  • TracedArrays treated as nodes by default

    Currently, for convenience, all non-Tree fields that are not declared are set to static fields, as most fields actually are. However, in more complex applications a Traced Array might be passed where a static field is usually expected.

    A simple solution is to change the current node policy to treat any field containing a TracedArray as a node; this would be the same as the current policy for Tree fields.

    opened by cgarciae 0
Releases(0.2.1)
Owner
Cristian Garcia
ML Engineer at Quansight, working on Treex and Elegy.