A small library for creating and manipulating custom JAX Pytree classes

Overview

Treeo

  • Light-weight: has no dependencies other than jax.
  • Compatible: Treeo Tree objects are compatible with any jax function that accepts Pytrees.
  • Standards-based: treeo.field is built on top of python's dataclasses.field.
  • Flexible: Treeo is compatible with both dataclass and non-dataclass classes.

Treeo lets you create class-based Pytrees so your custom objects can interact seamlessly with JAX. Uses of Treeo range from creating simple JAX-aware utility classes to serving as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of Treex and shares a lot in common with flax.struct.

Documentation | User Guide

Installation

Install using pip:

pip install treeo

Basics

With Treeo you can easily define your own custom Pytree classes by inheriting from Treeo's Tree class and using the field function to declare which fields are nodes (children) and which are static (metadata):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True) # I am a node field!
    name: str = to.field(node=False) # I am a static field!

field is just a wrapper around dataclasses.field so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well. Since all Tree instances are Pytrees, they work with the various functions from the jax library as expected:

p = Person(height=jnp.array(1.8), name="John")

# Trees can be jitted!
jax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')

# Trees can be mapped!
jax.tree_util.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')
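
To make the node/static split more concrete, the sketch below mimics the idea in plain Python using only the standard library. It is not Treeo's implementation: `sketch_field`, `flatten`, and `unflatten` are hypothetical helpers that store a `node` flag in the `dataclasses.field` metadata and use it to separate children from static data, which is roughly the information a Pytree flatten/unflatten pair needs.

```python
import dataclasses
from dataclasses import dataclass


def sketch_field(node: bool, **kwargs):
    """Hypothetical field wrapper: records the `node` flag in the metadata."""
    return dataclasses.field(metadata={"node": node}, **kwargs)


def flatten(obj):
    """Split a dataclass instance into (children, static) dictionaries."""
    children, static = {}, {}
    for f in dataclasses.fields(obj):
        target = children if f.metadata.get("node", False) else static
        target[f.name] = getattr(obj, f.name)
    return children, static


def unflatten(cls, children, static):
    """Rebuild an instance from the two halves produced by `flatten`."""
    return cls(**children, **static)


@dataclass
class Person:
    height: float = sketch_field(node=True)
    name: str = sketch_field(node=False)


p = Person(height=1.8, name="John")
children, static = flatten(p)

# Transform only the children, as a tree map would; static data is untouched.
doubled = unflatten(Person, {k: 2 * v for k, v in children.items()}, static)
print(doubled)  # Person(height=3.6, name='John')
```

Keeping the flag in field metadata means the class stays an ordinary dataclass; only the flatten step needs to know which fields are children.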

Kinds

Treeo also includes a kind system that lets you give semantic meaning to fields (what a field represents within your application). A kind is just a type you pass to field via its kind argument:

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

Kinds are very useful as a filtering mechanism via treeo.filter:

model = BatchNorm(...)

# select only Parameters, mean is filtered out
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)

Nothing behaves like None in Python, but it is a special value that is used to represent the absence of a value within Treeo.
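
The filtering idea can be sketched without JAX or Treeo. In the standard-library example below, `kind_field`, `filter_by_kind`, and the `NOTHING` sentinel are hypothetical names, and matching kinds with `issubclass` is an assumption about how such a filter might reasonably behave, not a statement about Treeo's internals.

```python
import dataclasses
from dataclasses import dataclass


class _Nothing:
    """Hypothetical sentinel marking a field that was filtered out."""

    def __repr__(self):
        return "Nothing"


NOTHING = _Nothing()


class Parameter:
    pass


class BatchStat:
    pass


def kind_field(kind, **kwargs):
    """Hypothetical field wrapper: records the kind in the metadata."""
    return dataclasses.field(metadata={"kind": kind}, **kwargs)


def filter_by_kind(obj, kind):
    """Return a copy of `obj` where fields of other kinds become NOTHING."""
    updates = {
        f.name: NOTHING
        for f in dataclasses.fields(obj)
        if not issubclass(f.metadata["kind"], kind)
    }
    return dataclasses.replace(obj, **updates)


@dataclass
class BatchNorm:
    scale: float = kind_field(Parameter)
    mean: float = kind_field(BatchStat)


model = BatchNorm(scale=1.0, mean=0.0)
params = filter_by_kind(model, Parameter)
print(params)  # BatchNorm(scale=1.0, mean=Nothing)
```

Using a dedicated sentinel rather than None keeps "this field was filtered out" distinguishable from "this field is legitimately None".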

Treeo also offers the merge function, which lets you rejoin filtered Trees with logic similar to Python's dict.update, applied recursively:

def loss_fn(params, model, ...):
    # add traced params to model
    model = to.merge(model, params)
    ...

# gradient only w.r.t. params
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)
grads = jax.grad(loss_fn)(params, model, ...)
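
The "recursive dict.update" behaviour can be illustrated with nested dictionaries. This is a minimal sketch under stated assumptions, not Treeo's merge: the `merge` function and `NOTHING` sentinel here are hypothetical, and the sketch assumes that entries marked as absent in the update leave the base value in place.

```python
class _Nothing:
    """Hypothetical sentinel for an absent value."""

    def __repr__(self):
        return "Nothing"


NOTHING = _Nothing()


def merge(base, update):
    """Recursively overlay `update` on `base`, skipping NOTHING entries."""
    merged = dict(base)  # shallow copy so `base` is not mutated
    for key, value in update.items():
        if value is NOTHING:
            continue  # absent in the update: keep the base value
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)  # recurse into subtrees
        else:
            merged[key] = value
    return merged


model = {"norm": {"scale": 1.0, "mean": 0.5}, "name": "bn"}
params = {"norm": {"scale": 2.0, "mean": NOTHING}}

print(merge(model, params))
# {'norm': {'scale': 2.0, 'mean': 0.5}, 'name': 'bn'}
```

Here the updated scale replaces the old one, while mean and name survive from the base because the update either marks them as absent or does not mention them.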

For a more in-depth tour check out the User Guide.

Examples

A simple Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Character(to.Tree):
    position: jnp.ndarray = to.field(node=True)    # node field
    name: str = to.field(node=False, opaque=True)  # static field

character = Character(position=jnp.array([0, 0]), name='Adam')

# character can freely pass through jit
@jax.jit
def update(character: Character, velocity, dt) -> Character:
    character.position += velocity * dt
    return character

character = update(character, velocity=jnp.array([1.0, 0.2]), dt=0.1)

A Stateful Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Counter(to.Tree):
    n: jnp.array = to.field(default=jnp.array(0), node=True) # node
    step: int = to.field(default=1, node=False) # static

    def inc(self):
        self.n += self.step

counter = Counter(step=2) # Counter(n=jnp.array(0), step=2)

@jax.jit
def update(counter: Counter):
    counter.inc()
    return counter

counter = update(counter) # Counter(n=jnp.array(2), step=2)

# map over the tree

Full Example - Linear Regression

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import treeo as to


class Linear(to.Tree):
    w: jnp.ndarray = to.node()
    b: jnp.ndarray = to.node()

    def __init__(self, din, dout, key):
        self.w = jax.random.uniform(key, shape=(din, dout))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b


@jax.value_and_grad
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y) ** 2)

    return loss


def sgd(param, grad):
    return param - 0.1 * grad


@jax.jit
def train_step(model, x, y):
    loss, grads = loss_fn(model, x, y)
    model = jax.tree_util.tree_map(sgd, model, grads)

    return loss, model


x = np.random.uniform(size=(500, 1))
y = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))

key = jax.random.PRNGKey(0)
model = Linear(1, 1, key=key)

for step in range(1000):
    loss, model = train_step(model, x, y)
    if step % 100 == 0:
        print(f"loss: {loss:.4f}")

X_test = np.linspace(x.min(), x.max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(x, y, c="k", label="data")
plt.plot(X_test, y_pred, c="b", linewidth=2, label="prediction")
plt.legend()
plt.show()
Comments
  • Use field kinds within tree_map

    Firstly, thanks for creating Treeo - it's a fantastic package.

    Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class Parameter:
        def transform(self):
            return jnp.exp(self)
    
    
    @dataclass
    class Model(to.Tree):
        lengthscale: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=Parameter
        )
    

    is there a way that I could do something similar to the following pseudocode snippet:

    m = Model()
    jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
    
    opened by thomaspinder 10
  • Stacking of Treeo.Tree

    I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    @dataclass
    class Person(to.Tree):
        height: jnp.array = to.field(node=True) # I am a node field!
        age_static: jnp.array = to.field(node=False) # I am a static field!, I should not be updated.
        name: str = to.field(node=False) # I am a static field!
    
    persons = [
        Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
        Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
        Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
    ]
    
    # Stack (struct of arrays instead of list of structs)
    jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    

    However, this fails with the following exception:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In[1], line 18
         11     name: str = to.field(node=False) # I am a static field!
         13 persons = [
         14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
         15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
         16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
         17 ]
    ---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
    

    Versions used:

    • JAX: 0.3.20
    • Treeo: 0.0.10

    From a certain perspective this is expected because jax.tree_map does not apply to static (node=False) fields. In that sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still be able to stack objects like this with static fields. Has anyone tried something similar and come up with a nice solution?

    opened by peterroelants 3
  • Jitting twice for a class method

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class A(to.Tree):
        X: jnp.array = to.field(node=True)
        
        def __init__(self):
            self.X = jnp.ones((50, 50))
    
        @jax.jit
        def f(self, Y):
            return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)
    
    Y = jnp.ones(2)
    for i in range(5):
        print(A.f._cache_size())
        a = A()
        a.f(Y)
    

    The output of the above is 0 1 2 2 2 with jax 0.3.15. No idea what's happening. It seems to work fine with 0.3.10 and the output is 0 1 1 1 1. Thanks.

    opened by pipme 2
  • Change Mutable API

    Changes

    • Previously self.mutable(*args, method=method, **kwargs)
    • Is now...... self.mutable(method=method)(*args, **kwargs)
    • Opaque API is removed
    • inplace argument is now only available for apply.
    • Immutable.{mutable, toplevel_mutable} methods are removed.
    fix 
    opened by cgarciae 1
  • Improve mutability support

    Changes

    • Fixes issues with immutability in compact context
    • The make_mutable context manager and the mutable function now expose a toplevel_only: bool argument.
    • Adds a _get_unbound_method private function in utils.
    feature 
    opened by cgarciae 1
  • Bug Fixes from 0.0.11

    Changes

    • Fixes an issue that disabled mutability inside __init__ for Immutable classes when TreeMeta's constructor method is overloaded.
    • Fixes the Apply.apply mixin method.

    Closes cgarciae/treex#68

    fix 
    opened by cgarciae 1
  • Adds support for immutable Trees

    Changes

    • Adds an Immutable mixin that can make Trees effectively immutable (as far as python permits).
    • Immutable contains the .replace and .mutable methods that let you manipulate state in a functionally pure fashion.
    • Adds the mutable function transformation / decorator, which lets you turn functions that perform mutable operations into pure functions.
    opened by cgarciae 1
  • Add the option of using add_field_info inside map

    This PR addresses the comments made in #2. An additional argument is added to map so that a field_info boolean flag can be passed. When true, jax.tree_map is carried out under the with add_field_info(): context manager.

    Tests have been added to check for correct function application on classes containing Trees with mixed kind types.

    A brief section has been added to the documentation to reflect the above changes.

    opened by thomaspinder 1
  • Get all unique kinds

    Hi,

    Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

    class KindOne: pass
    class KindTwo: pass
    
    @dataclass
    class SubModel(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindOne
        )
    
    
    @dataclass 
    class Model(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindTwo
        )
    
    m = Model()
    
    m.unique_kinds() # [KindOne, KindTwo]
    
    opened by thomaspinder 1
  • Compact

    Changes

    • Removes opaque_is_equal, same functionality available through opaque.
    • Adds the compact decorator, which enables the definition of Tree subnodes at runtime.
    • Adds the Compact mixin that adds the first_run property and the get_field method.
    opened by cgarciae 0
  • Relax jax/jaxlib version constraints

    Now that jax 0.3.0 and jaxlib 0.3.0 have been released the version constraints in pyproject.toml are outdated.

    https://github.com/cgarciae/treeo/blob/a402f3f69557840cfbee4d7804964b8e2c47e3f7/pyproject.toml#L16-L17

    This corresponds to the version constraint jax<0.3.0,>=0.2.18 (https://python-poetry.org/docs/dependency-specification/#caret-requirements). Now that jax v0.3.0 has been released (https://github.com/google/jax/releases/tag/jax-v0.3.0) this doesn't work with the latest version. I think the same applies to jaxlib as well, since it also got upgraded to v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).

    opened by samuela 4
  • TracedArrays treated as nodes by default

    Currently, for convenience, all non-Tree fields that are not declared are set to static fields, as most fields actually are. However, in more complex applications a Traced Array might be passed where a static field is usually expected.

    A simple solution is to change the current node policy to treat any field containing a TracedArray as a node; this would be the same as the current policy for Tree fields.

    opened by cgarciae 0
Releases(0.2.1)
Owner
Cristian Garcia
ML Engineer at Quansight, working on Treex and Elegy.