JAX + dataclasses

Overview

jax_dataclasses

jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT compilation, etc.
  • Serialization via flax.serialization.

Notably, jax_dataclasses is designed to work seamlessly with static analysis, including tools like mypy and jedi.
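
One general way for a decorator to stay transparent to static analysis is to alias it to the standard decorator under typing.TYPE_CHECKING and add the extra behavior only at runtime. The sketch below illustrates that pattern with the standard library alone. The names registered_dataclass and REGISTRY are hypothetical, and the registry is only a stand-in for pytree registration; this is not presented as the library's actual implementation.

```python
import dataclasses
from typing import TYPE_CHECKING

# Hypothetical stand-in for "register this class as a pytree container".
REGISTRY = []

if TYPE_CHECKING:
    # Static tools (mypy, jedi, ...) only ever see the standard decorator.
    from dataclasses import dataclass as registered_dataclass
else:

    def registered_dataclass(cls=None, **kwargs):
        """Runtime version: a frozen dataclass plus a registration side effect."""
        kwargs["frozen"] = True

        def wrap(c):
            c = dataclasses.dataclass(c, **kwargs)
            REGISTRY.append(c)
            return c

        # Support both @registered_dataclass and @registered_dataclass(...).
        return wrap if cls is None else wrap(cls)


@registered_dataclass
class Point:
    x: float
    y: float


point = Point(x=1.0, y=2.0)
```

Because the type checker resolves the decorator to dataclasses.dataclass, field types, constructor signatures, and autocomplete behave exactly as they would for a plain dataclass.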

Heavily influenced by some great existing work; see Alternatives for comparisons.

Installation

pip install jax_dataclasses

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
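
To make the distinction between children and static fields concrete, the sketch below registers a standard frozen dataclass by hand with jax.tree_util.register_pytree_node. It only illustrates the underlying JAX mechanism and is not the library's own code; the class Scaled and the helpers _flatten, _unflatten, and apply are hypothetical names.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Scaled:
    values: jnp.ndarray  # array field: becomes a pytree child (a leaf)
    power: int  # static field: stored in the treedef as auxiliary data


def _flatten(obj):
    # Return (children, auxiliary data). Auxiliary data must be hashable.
    return (obj.values,), obj.power


def _unflatten(power, children):
    return Scaled(values=children[0], power=power)


jax.tree_util.register_pytree_node(Scaled, _flatten, _unflatten)


@jax.jit
def apply(obj: Scaled) -> jnp.ndarray:
    # obj.power is a plain Python int here; obj.values is traced.
    return obj.values**obj.power


a = Scaled(values=jnp.arange(3.0), power=2)
b = Scaled(values=jnp.arange(3.0), power=3)

leaves = jax.tree_util.tree_leaves(a)  # contains only `values`
same_structure = jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)
result = apply(a)
```

Since the static value is part of the treedef, two objects that differ only in power have different tree structures, and a jitted function is retraced when the static value changes.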

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, we provide an interface that will (a) make a copy of a pytree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
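
For intuition, a "copy, then temporarily unfreeze" context can be sketched with the standard library alone. The version below is an assumption-laden illustration, not the library's implementation: copy_and_mutate_sketch, _walk, Inner, and Outer are hypothetical names, it handles only dataclasses nested directly in fields, and it performs no shape or dtype validation.

```python
import contextlib
import copy
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:
    value: int


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner


def _walk(obj):
    """Yield obj and every dataclass instance nested in its fields."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        yield obj
        for field in dataclasses.fields(obj):
            yield from _walk(getattr(obj, field.name))


@contextlib.contextmanager
def copy_and_mutate_sketch(obj):
    clone = copy.deepcopy(obj)  # the original is never touched
    nodes = list(_walk(clone))
    frozen_types = [type(node) for node in nodes]
    for node, cls in zip(nodes, frozen_types):
        # Swap in a subclass whose __setattr__ bypasses the frozen check.
        mutable = type(cls.__name__, (cls,), {"__setattr__": object.__setattr__})
        object.__setattr__(node, "__class__", mutable)
    try:
        yield clone
    finally:
        # Restore the frozen classes so the copy is immutable again.
        for node, cls in zip(nodes, frozen_types):
            object.__setattr__(node, "__class__", cls)


outer = Outer(inner=Inner(value=1))
with copy_and_mutate_sketch(outer) as updated:
    updated.inner.value = 2  # allowed only inside the context
```

After the block exits, updated is frozen again and outer still holds its original value.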

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. Libraries like dataclasses and attrs rely on tooling-specific custom plugins for static analysis, and no such plugins exist for chex or flax. tjax has a custom mypy plugin that enables type checking, but it isn't supported by other tools. Because @jax_dataclasses.pytree_dataclass has the same API as @dataclasses.dataclass, it can add pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses is generally very frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jax_dataclasses.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
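
The nested-dataclass point is easiest to see with the standard dataclasses.replace API. The following sketch uses only the standard library; the Joint, Arm, and Robot classes are hypothetical examples.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Joint:
    angle: float


@dataclasses.dataclass(frozen=True)
class Arm:
    elbow: Joint


@dataclasses.dataclass(frozen=True)
class Robot:
    arm: Arm


robot = Robot(arm=Arm(elbow=Joint(angle=0.0)))

# Every level on the path to the changed field has to be rebuilt by hand.
updated = dataclasses.replace(
    robot,
    arm=dataclasses.replace(
        robot.arm,
        elbow=dataclasses.replace(robot.arm.elbow, angle=1.5),
    ),
)
```

Each additional level of nesting adds another replace call, whereas inside a copy_and_mutate context the same change is a single attribute assignment on the copy.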

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback!

Comments
  • Fix infinite loop for cycles in pytrees

    I have a rather big dataclass to describe a robot model, that includes a graph of links and a list of joints. Each node of the graph references the parent link and all the child links. Each joint object references its parent and child links.

    When I try to copy_and_mutate any of these objects, an infinite loop occurs, possibly because of all this nesting. I suspect that the existing logic tries to unfreeze all the leaves of the pytree, but the high interconnection and the properties of mutable Python types lead to a never-ending unfreezing process.

    This PR addresses this edge case by storing the IDs of objects that have already been unfrozen. It solves my problem, and it should not add any noticeable performance degradation.
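
    The idea of tracking already-visited object IDs can be sketched independently of the library. The example below is a generic illustration, not the code from the PR; Link and visit_once are hypothetical names.

```python
import dataclasses
from typing import Any, List, Optional, Set


@dataclasses.dataclass
class Link:
    name: str
    parent: Optional["Link"] = None
    children: List["Link"] = dataclasses.field(default_factory=list)


def visit_once(obj: Any, seen: Optional[Set[int]] = None) -> List[str]:
    """Collect link names, skipping any object whose id() was already seen."""
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return []  # already handled: this check is what breaks the cycle
    seen.add(id(obj))

    names: List[str] = []
    if isinstance(obj, Link):
        names.append(obj.name)
        names += visit_once(obj.parent, seen)
        names += visit_once(obj.children, seen)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            names += visit_once(item, seen)
    return names


base = Link("base")
arm = Link("arm", parent=base)
base.children.append(arm)  # base -> arm -> base forms a reference cycle

names = visit_once(base)
```

    Without the seen set, the traversal would follow base to arm and back to base indefinitely.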

    cc @brentyi

    opened by diegoferigo 10
  • Delayed initialisation of static fields

    First of all, thank you for the amazing library! I have recently discovered jax_dataclasses and I have decided to port my messy JAX functional code to a more organised object-oriented code based on jax_dataclasses.

    In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.

    Here is a simple example, where the attribute _sum is a derived static field that depends on the constant value of the array a.

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    
    @jdc.pytree_dataclass()
    class PyTreeDataclass:
        a: jnp.ndarray
        _sum: int = jdc.static_field(init=False, repr=False)
    
        def __post_init__(self):
            object.__setattr__(self, "_sum", self.a.sum().item())
    
    def print_pytree(obj):
        print(obj._sum)
    
    obj = PyTreeDataclass(jnp.arange(4))
    print_pytree(obj)
    jax.jit(print_pytree)(obj)
    

    The non-jitted version works, but when print_pytree is jitted I get the following error.

    File "jax_dataclasses_issue.py", line 14, in __post_init__
        object.__setattr__(self, "_sum", self.a.sum().item())
    AttributeError: 'bool' object has no attribute 'sum'
    

    Is there a way to compute in the __post_init__ the value of static fields not initialized in __init__ that depend on jnp.ndarray attributes of the dataclass?

    opened by lucagrementieri 4
  • `jax.tree_leaves` is deprecated

    The file jax_dataclasses/_copy_and_mutate.py raises many warnings about a deprecated function.

    FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
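
    The replacement named in the warning can be called as in this minimal sketch; the tree variable is a hypothetical example.

```python
import jax
import jax.numpy as jnp

# A small pytree: a dict holding an array and a tuple of (array, scalar).
tree = {"a": jnp.zeros(3), "b": (jnp.ones(2), 1.0)}

# Deprecated spelling named in the warning: jax.tree_leaves(tree)
leaves = jax.tree_util.tree_leaves(tree)
```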
    
    opened by lucagrementieri 1
  • Use jaxtyping to enrich type annotations

    I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

    jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes, see https://github.com/google/jaxtyping/blob/main/API.md and it has very few requirements.

    Do you think that it could be added to jax_dataclasses?

    opened by lucagrementieri 4
  • Serialization of static fields?

    Thanks for the handy library!

    I have a pytree_dataclass that contains a few static_fields that I would like to have serialized by the facilities in flax.serialization. I noticed that jax_dataclasses.asdict handles these, but that flax.serialization.to_state_dict and flax.serialization.to_bytes both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?

    import jax_dataclasses as jdc
    from jax import numpy as jnp
    import flax.serialization as fs
    
    
    @jdc.pytree_dataclass
    class Demo:
        a: jnp.ndarray = jnp.ones(3)
        b: bool = jdc.static_field(default=False)
    
    
    demo = Demo()
    print(f'{jdc.asdict(demo) = }')
    print(f'{fs.to_state_dict(demo) = }')
    print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')
    
    # jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
    # fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
    # fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}
    

    Thanks in advance!

    opened by erdmann 3
Releases(v1.5.1)
Owner
Brent Yi