JAX + dataclasses

Overview

jax_dataclasses

jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT compilation, etc.
  • Serialization via flax.serialization.

Notably, jax_dataclasses is designed to work seamlessly with static analysis, including tools like mypy and jedi.

Heavily influenced by some great existing work; see Alternatives for comparisons.

Installation

pip install jax_dataclasses

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
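As a rough illustration of what pytree registration with static fields involves, the sketch below registers a frozen dataclass by hand using jax.tree_util.register_pytree_node. The register_as_pytree helper and the "static" metadata key are hypothetical names chosen for this example. This is not how jax_dataclasses is implemented; it only pictures the split between children and treedef described above.

```python
import dataclasses

import jax
import jax.numpy as jnp


def register_as_pytree(cls):
    """Hypothetical helper: make `cls` a frozen dataclass and a pytree node."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    # Fields tagged with metadata={"static": True} go into the treedef.
    static_names = [f.name for f in fields if f.metadata.get("static", False)]
    child_names = [f.name for f in fields if f.name not in static_names]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in child_names)
        aux_data = tuple(getattr(obj, name) for name in static_names)
        return children, aux_data

    def unflatten(aux_data, children):
        kwargs = dict(zip(child_names, children))
        kwargs.update(zip(static_names, aux_data))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_as_pytree
class Params:
    weights: jnp.ndarray
    name: str = dataclasses.field(default="layer", metadata={"static": True})


params = Params(weights=jnp.ones(3))

# Only `weights` is a leaf; `name` travels with the tree structure.
leaves = jax.tree_util.tree_leaves(params)
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# The registered class can cross a jit boundary.
total = jax.jit(lambda p: p.weights.sum())(params)
```

Because the static values are part of the tree structure, they need to be hashable and comparable; a string or an int works, an array does not.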

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, we provide an interface that will (a) make a copy of a pytree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. Libraries like dataclasses and attrs rely on tooling-specific custom plugins for static analysis, which don't exist for chex or flax. tjax has a custom mypy plugin to enable type checking, but that plugin isn't supported by other tools. Because @jax_dataclasses.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses is generally very frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jax_dataclasses.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
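To make the nesting problem concrete, here is a small standard-library sketch. The Inner, Middle, and Outer classes are made up for illustration. Changing one leaf of a three-level frozen dataclass with dataclasses.replace means rebuilding every level on the path to that leaf:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:
    value: float


@dataclasses.dataclass(frozen=True)
class Middle:
    inner: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    middle: Middle


outer = Outer(middle=Middle(inner=Inner(value=1.0)))

# One replace() call per nesting level, written inside out.
updated = dataclasses.replace(
    outer,
    middle=dataclasses.replace(
        outer.middle,
        inner=dataclasses.replace(outer.middle.inner, value=2.0),
    ),
)
```

With copy_and_mutate, the same update would be a single attribute assignment inside the context, as in the Mutations example above.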

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback!

Comments
  • Fix infinite loop for cycles in pytrees

    I have a rather big dataclass to describe a robot model, that includes a graph of links and a list of joints. Each node of the graph references the parent link and all the child links. Each joint object references its parent and child links.

    When I try to copy_and_mutate any of these objects, maybe due to all this nesting, an infinite loop occurs. I suspect that the existing logic tries to unfreeze all the leaves of the pytree, but the high interconnection and the properties of mutable Python types lead to a never-ending unfreezing process.

    This PR addresses this edge case by storing the list of IDs of objects that have already been unfrozen. It solves my problem, and it should not add any noticeable performance degradation.
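
The idea of tracking already-visited objects can be sketched as follows. This is a generic illustration with a hypothetical Link class and visit_once function, not the code from the PR:

```python
import dataclasses
from typing import Any, List, Optional, Set


@dataclasses.dataclass
class Link:
    name: str
    parent: Optional["Link"] = None
    children: List["Link"] = dataclasses.field(default_factory=list)


def visit_once(obj: Any, seen: Optional[Set[int]] = None) -> List[str]:
    """Walk a possibly cyclic object graph, visiting each object only once."""
    if seen is None:
        seen = set()
    if id(obj) in seen:
        # Already handled: stop here instead of recursing forever.
        return []
    seen.add(id(obj))

    names: List[str] = []
    if isinstance(obj, Link):
        names.append(obj.name)
        for f in dataclasses.fields(obj):
            names.extend(visit_once(getattr(obj, f.name), seen))
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            names.extend(visit_once(item, seen))
    return names


# Parent and child reference each other, which forms a cycle.
root = Link("base")
arm = Link("arm", parent=root)
root.children.append(arm)

visited = visit_once(root)
```

Without the `seen` set, the walk would bounce between `root` and `arm` until the recursion limit is hit.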

    cc @brentyi

    opened by diegoferigo 10
  • Delayed initialisation of static fields

    First of all, thank you for the amazing library! I have recently discovered jax_dataclasses and I have decided to port my messy JAX functional code to a more organised object-oriented code based on jax_dataclasses.

    In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.

    Here is a simple example, where the attribute _sum is a derived static field that depends on the constant value of the array a.

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    
    @jdc.pytree_dataclass()
    class PyTreeDataclass:
        a: jnp.ndarray
        _sum: int = jdc.static_field(init=False, repr=False)
    
        def __post_init__(self):
            object.__setattr__(self, "_sum", self.a.sum().item())
    
    def print_pytree(obj):
        print(obj._sum)
    
    obj = PyTreeDataclass(jnp.arange(4))
    print_pytree(obj)
    jax.jit(print_pytree)(obj)
    

    The non-jitted version works, but when print_pytree is jitted I get the following error.

    File "jax_dataclasses_issue.py", line 14, in __post_init__
        object.__setattr__(self, "_sum", self.a.sum().item())
    AttributeError: 'bool' object has no attribute 'sum'
    

    Is there a way to compute, in __post_init__, the value of static fields that are not initialized in __init__ and that depend on jnp.ndarray attributes of the dataclass?
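
One possible workaround, sketched under assumptions rather than taken from this thread: the traceback suggests that the object is rebuilt with placeholder leaves, so an array-dependent __post_init__ fails. Computing the derived value in a factory method that runs on concrete arrays, and passing it in as an ordinary static value, avoids running array code during reconstruction. The sketch uses plain jax.tree_util registration instead of jax_dataclasses to stay self-contained; WithDerived, create, and make_zeros are hypothetical names.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class WithDerived:
    a: jnp.ndarray
    total: int  # static: stored in the treedef, never traced

    @classmethod
    def create(cls, a):
        # Runs outside of jit on a concrete array, so int() is safe here.
        return cls(a=a, total=int(a.sum()))


# Manual registration (an assumption of this sketch): `a` is a child,
# `total` is auxiliary data that is passed back unchanged on unflatten.
jax.tree_util.register_pytree_node(
    WithDerived,
    lambda obj: ((obj.a,), (obj.total,)),
    lambda aux, children: WithDerived(a=children[0], total=aux[0]),
)


@jax.jit
def make_zeros(obj):
    # obj.total is a plain Python int here, so it can define a shape.
    return jnp.zeros(obj.total) + obj.a.sum()


obj = WithDerived.create(jnp.arange(4))
result = make_zeros(obj)
```

The trade-off is that callers must go through the factory; constructing the class directly with an inconsistent `total` is not prevented.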

    opened by lucagrementieri 4
  • `jax.tree_leaves` is deprecated

    The file jax_dataclasses/_copy_and_mutate.py raises many warnings about a deprecated function.

    FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
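
The replacement named in the warning lives in jax.tree_util. A minimal sketch with a made-up tree:

```python
import jax
import jax.numpy as jnp

# Arbitrary nested container used only for illustration.
tree = {"a": jnp.zeros(2), "b": (jnp.ones(3), 1.0)}

# Same behavior as the deprecated top-level alias jax.tree_leaves.
leaves = jax.tree_util.tree_leaves(tree)
```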
    
    opened by lucagrementieri 1
  • Use jaxtyping to enrich type annotations

    I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

    jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes (see https://github.com/google/jaxtyping/blob/main/API.md), and it has very few requirements.

    Do you think that it could be added to jax_dataclasses?

    opened by lucagrementieri 4
  • Serialization of static fields?

    Thanks for the handy library!

    I have a pytree_dataclass that contains a few static_fields that I would like to have serialized by the facilities in flax.serialization. I noticed that jax_dataclasses.asdict handles these, but that flax.serialization.to_state_dict and flax.serialization.to_bytes both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?

    import jax_dataclasses as jdc
    from jax import numpy as jnp
    import flax.serialization as fs
    
    
    @jdc.pytree_dataclass
    class Demo:
        a: jnp.ndarray = jnp.ones(3)
        b: bool = jdc.static_field(default=False)
    
    
    demo = Demo()
    print(f'{jdc.asdict(demo) = }')
    print(f'{fs.to_state_dict(demo) = }')
    print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')
    
    # jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
    # fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
    # fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}
    

    Thanks in advance!

    opened by erdmann 3
Releases(v1.5.1)
Owner
Brent Yi