Objax (🥉19 · ⭐ 580) - Objax is a machine learning framework that provides an Object.. License: Apache-2. Tags: jax


Objax

Tutorials | Install | Documentation | Philosophy

This is not an officially supported Google product.

Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX -- a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

This is the developer repository of Objax. There is very little user documentation here; for the full documentation, go to objax.readthedocs.io.

You can find READMEs in the subdirectories of this project, for example:

User installation guide

You install Objax using pip as follows:

pip install --upgrade objax

Objax supports GPUs but assumes that you already have some version of CUDA installed. Here are the extra steps:

# Update accordingly to your installed CUDA version
CUDA_VERSION=11.0
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`

Useful environment configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
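
The same option can also be set from inside a Python script. This is a minimal sketch, assuming the script controls its own imports; setting the variable before JAX is imported is the safest point, since a later change may be ignored.

```python
import os

# Set the flag before the first `import jax`; a later change may have no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax    # import JAX (and Objax) only after the variable is set
# import objax

print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])
```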

Testing your installation

You can test your installation by running the code below:

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors running this with CUDA, your installation of CUDA or cuDNN most likely has issues.

Running code examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples

Citing Objax

To cite this repository:

@software{objax2020github,
  author = {{Objax Developers}},
  title = {{Objax}},
  url = {https://github.com/google/objax},
  version = {1.2.0},
  year = {2020},
}

Developer documentation

Here is information about development setup and a guide on adding new code.

Comments
  • More control over var/module namespace.

    I got my first 'hello world' model experiment working w/ Objax. I adapted my PyTorch EfficientNet impl. Overall pretty smooth, currently wrapping Conv2d so I can get the padding I want.

    One thing that stuck out after inspecting the model, the var namespace is a mess. An aspect of modelling that I value highly is the ability to have sensible checkpoint/var maps to work with. I often end up dealing with conversions between frameworks, exports for mobile or embedded targets and having your vars (parameters) sensibly named, and often being able to control those names in the originating framework is important.

    Any thoughts on improving this? The current name/scoping mechanism forces the inclusion of the Module class names, is that necessary? Shouldn't attr names through the tree be enough for uniqueness?

    Also, there is no ability to specify names for modules in sequential containers. I use this quite often for frameworks that have it. Sometimes I don't care much (long list of block repeats, 0..n is fine), but for finer grained blocks I like to know what conv is what by looking at the var names. '0.b, 0.w' etc isn't very useful.

    I'll post an example of the var keys below, and comparison point for pytorch.

    feature request 
    opened by rwightman 29
  • upsample2d function rough draft

    Hi Team, I am pretty new to contributing to open source projects. Please review the upsample2d function and let me know of anything that is required or should be changed. The function is added in the objax.function.ops module.

    opened by naruto-raj 22
  • Add mean squared logarithmic loss function

    1. Added mean squared logarithmic loss function
    2. In the CONTRIBUTIONS.md file, there is no mention of code-style. So, I am using 4-spaces.
    3. I haven't formatted the code using black, as there is no mention of any formatter either.

    I will add the tests once the above points are clear

    opened by AakashKumarNain 16
  • Initial dot product attention

    Adds attention, per #61. First, I'm really sorry about taking so long, but college got complicated in the pandemic and I wasted a lot of time getting organized. Also, attention is quite a general concept, and even implementations of the same type of attention differ significantly (haiku, flax). So @david-berthelot and @aterzis-google, I would like to ask a few questions to make sure my implementation is going in the right direction:

    1. I think I will implement a dot product attention, a multi-head attention and a masked attention, is that ok?
    2. What do you think of the dot product attention implementation? What do you think I need to change? Thanks for the patience and opportunity.
    opened by joaogui1 12
  • "objax.variable.VarCollection is not a valid JAX type" when creating a custom optimizer

    Hi, I wish to create a custom optimizer to replace the opt(lr=lr, grads=g) line in the example https://github.com/google/objax/blob/master/examples/classify/img/cifar10_simple.py

    Instead, I replaced it with

    for grad, p in zip(g, model_vars):
          p.value -= lr * grad   
    

    and then supplied model.vars() as an argument to train_op. However, I received an error: objax.variable.VarCollection is not a valid JAX type. Can someone help me with this issue? Here is a minimal working example which reproduces the error.

    import random
    import numpy as np
    import objax  # needed for objax.Jit, objax.GradValues and objax.functional below
    import tensorflow as tf
    from objax.zoo.wide_resnet import WideResNet
    
    # Data
    (X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
    X_train = X_train.transpose(0, 3, 1, 2) / 255.0
    X_test = X_test.transpose(0, 3, 1, 2) / 255.0
    
    # Model
    model = WideResNet(nin=3, nclass=10, depth=28, width=2)
    #opt = objax.optimizer.Adam(model.vars())
    predict = objax.Jit(lambda x: objax.functional.softmax(model(x, training=False)),
                        model.vars())
    # Losses
    def loss(x, label):
        logit = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
    
    gv = objax.GradValues(loss, model.vars())
    
    def train_op(x, y, model_vars, lr):
        g, v = gv(x, y)
        for grad, p in zip(g, model_vars):
          p.value -= lr * grad   
        return v
    
    
    # gv.vars() contains the model variables.
    train_op = objax.Jit(train_op, gv.vars()) #I deleted opt.vars()
    
    for epoch in range(30):
        # Train
        loss = []
        sel = np.arange(len(X_train))
        np.random.shuffle(sel)
        for it in range(0, X_train.shape[0], 64):
            loss.append(train_op(X_train[sel[it:it + 64]], Y_train[sel[it:it + 64]].flatten(), model.vars(), 4e-3 if epoch < 20 else 4e-4)) #I added model.vars() 
    
    opened by RXZ2020 11
  • Enforcing positivity (or other transformations) of TrainVars

    Hi,

    Is it possible to declare constraints on trainable variables, e.g. forcing them to be positive via an exponential or softplus transformation?

    In an ideal world, we would be able to write something like: self.variance = objax.TrainVar(np.array(1.0), transform=positive)
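
A minimal sketch of the reparameterization being asked for, in plain Python rather than Objax: the optimizer would update an unconstrained raw value, and the model would only ever read its softplus. The names `PositiveParam`, `softplus` and `inverse_softplus` are hypothetical and do not correspond to an existing Objax API.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); the result is never negative."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inverse_softplus(y: float) -> float:
    """Return x such that softplus(x) == y, for y > 0."""
    return y + math.log(-math.expm1(-y))


class PositiveParam:
    """Hypothetical wrapper: the optimizer sees `raw`, the model sees `value`."""

    def __init__(self, init: float):
        if init <= 0:
            raise ValueError("init must be positive")
        self.raw = inverse_softplus(init)  # unconstrained, trainable

    @property
    def value(self) -> float:
        return softplus(self.raw)  # constrained view of the same parameter


variance = PositiveParam(1.0)
initial_value = variance.value
variance.raw -= 5.0  # stands in for a gradient step on the raw value
print(initial_value, variance.value)
```

Because the constraint lives in the forward transformation, no projection or clipping step is needed after an update; an exponential transform would work the same way with `math.exp` and `math.log`.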

    Thanks,

    Will

    p.s. thanks for the great work on objax so far, it's a pleasure to use.

    opened by wil-j-wil 10
  • Training state as a Module attribute

    As mentioned in a Twitter thread, I am curious about the decision to propagate training state through the call() chain. From my perspective this approach adds more boilerplate code, and more chance of making a mistake (not propagating the state to a few instances of a module with a BN or dropout layer, etc). If the state changed every call like the input data, it would make more sense to pass it with every forward, but I can't think of cases where that is common? For small models it doesn't make much difference, but as they grow with more depth and breadth of submodules, the extra args are more noticeable.

    I feel one of the major benefits of an OO abstraction for NN is being able to push some attributes like this into the class structure vs forcing it to be forwarded through every call in a functional manner. I sit in the middle ground (pragmatic) of OO vs functional. Hidden state can be problematic, but worth it if it keeps interfaces clean.

    Besides TF/Keras, most DL libs manage training state as a module attr or some sort of context:

    • PyTorch - nn.Module has a self.training attribute, recursively set on train()/eval() calls on the model/modules - https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval
    • MxNet Gluon - a context manager sets scope with autograd.train_mode() with autograd.predict_mode() - https://gluon.mxnet.io/chapter03_deep-neural-networks/mlp-dropout-gluon.html
    • Swift for TF - a thread-local context holds learningPhase - https://www.tensorflow.org/swift/api_docs/Structs/Context
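
The attribute-based pattern in the first bullet can be sketched in plain Python. The `Module`, `Dropout` and `Net` classes below are hypothetical stand-ins, not Objax or PyTorch code; they only show the flag being set once on the tree instead of being passed through every call.

```python
import random


class Module:
    """Hypothetical base class that stores the training flag on the instance."""

    def __init__(self):
        self.training = True

    def children(self):
        # Every attribute that is itself a Module counts as a submodule.
        return [v for v in vars(self).values() if isinstance(v, Module)]

    def train(self, mode: bool = True):
        self.training = mode
        for child in self.children():
            child.train(mode)  # propagate once, not on every forward call
        return self

    def eval(self):
        return self.train(False)


class Dropout(Module):
    def __init__(self, keep: float = 0.5):
        super().__init__()
        self.keep = keep

    def __call__(self, x):
        if not self.training:
            return x  # inference: identity
        return [v / self.keep if random.random() < self.keep else 0.0 for v in x]


class Net(Module):
    def __init__(self):
        super().__init__()
        self.drop = Dropout()

    def __call__(self, x):
        return self.drop(x)  # no `training=` argument to forward


net = Net().eval()
print(net.drop.training, net([1.0, 2.0]))
```

The trade-off discussed in this thread is visible here: the call signature stays clean, but the behavior of `net(x)` depends on state that is not visible at the call site.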

    It should be noted that Swift for TF started out Keras- and objax-like, with the training state passed through call().

    Disclaimer: I like PyTorch, I do quite a bit of work with that framework. It's not perfect but I feel they really did a good job in terms of interface, usability, evolution of the API. I've read some other comments here and acknowledge the 'we don't want to be like framework/lib X, or Y just because. If you disagree go fork yourself'. Understood, any suggestions I make are not just to be like X, but to bring elements of X that work really well to improve this library.

    I currently maintain some PyTorch model collections, https://github.com/rwightman/pytorch-image-models and https://github.com/rwightman/efficientdet-pytorch as examples. I'm running into a cost ($$) wall with GPU experiments supporting my OS work. TPU costing is starting to look far more attractive. PyTorch XLA is not proving to be a great option, but JAX with a productive interface looks like it could be a winning solution with even more flexibility.

    I'm willing to contribute code for changes like this, but at this point it's a matter of design philosophy :)

    opened by rwightman 9
  • Implementing 2 phases DP-SGD

    This PR implements a two-phase algorithm for per-sample gradient clipping with the goal of improving memory efficiency for the training of private deep models. The two steps are: (1) accumulate the norms of the gradient per sample and (2) use those norm values to perform a weighted backward pass that is equivalent to per-sample clipping. The user can choose whether to use this new algorithm or the currently implemented one through a boolean argument.
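
The equivalence between the two phases and direct per-sample clipping can be checked on a toy problem. This is a plain-Python sketch on a hypothetical linear model with squared loss, not the code from the PR; it omits the noise that DP-SGD adds afterwards, and it assumes no per-sample gradient has zero norm.

```python
import math

# Toy setup: linear model w.x with per-sample loss 0.5 * (w.x - y)^2.
w = [0.5, -1.0]
xs = [[1.0, 2.0], [3.0, -1.0], [0.1, 0.2]]
ys = [1.0, 0.0, 2.0]
max_norm = 1.0  # clipping threshold (hypothetical value)


def residual(x, y):
    return sum(wi * xi for wi, xi in zip(w, x)) - y


def per_sample_grad(x, y):
    r = residual(x, y)
    return [r * xi for xi in x]


def l2_norm(g):
    return math.sqrt(sum(gi * gi for gi in g))


# Baseline: compute every per-sample gradient, clip it, then sum.
baseline = [0.0, 0.0]
for x, y in zip(xs, ys):
    g = per_sample_grad(x, y)
    scale = min(1.0, max_norm / l2_norm(g))
    baseline = [b + scale * gi for b, gi in zip(baseline, g)]

# Phase 1: keep only the per-sample gradient norms, turned into weights.
norms = [l2_norm(per_sample_grad(x, y)) for x, y in zip(xs, ys)]
weights = [min(1.0, max_norm / n) for n in norms]

# Phase 2: a single gradient of sum_i weights[i] * loss_i,
# with the weights treated as constants.
two_phase = [0.0, 0.0]
for x, y, c in zip(xs, ys, weights):
    r = residual(x, y)
    two_phase = [t + c * r * xi for t, xi in zip(two_phase, x)]

print(all(math.isclose(a, b) for a, b in zip(baseline, two_phase)))
```

In a real implementation the memory saving would come from obtaining the norms without storing the full per-sample gradients; the toy loop only demonstrates that the weighted pass reproduces the clipped sum.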

    The unit-tests have been adapted to check results for both algorithms.

    Let me know if this fits well!

    opened by lberrada 7
  • Give better error message when calling Parallel() without replicate()

    Currently if you forget to call replicate() on a Parallel module, it dies somewhere in JAX land in between the 5th and 6th circles of hell. This error makes it possible to understand what's going on and find your way back.

    opened by carlini 7
  • Naming of the `GradValues` function

    If I understand right, GradValues essentially does two things: computing gradients and computing model final values.

    So why not split it into two functions? Or, if we keep the current form, could we name it GradAndValuesFn? This is a prominent function, and it should stay as easy as possible for people beginning to use the framework. Easy names like fit() and predict() are what made scikit-learn.

    opened by jli05 6
  • Explicit padding mode

    It looks like objax currently limits padding to one of VALID or SAME. This rules out explicit padding and prevents compatibility with models from PyTorch or Gluon, which only support explicit (symmetric) padding, unless extra Pad layers are added to the model.

    It'd be nice to at minimum add the ability to support TF style explicit padding (specify both sides of every dim). The underlying jax conv impl is able to receive a [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]] spec like other low level TF conv.

    Even nicer would be simplified, per-spatial-dim symmetric values like PyTorch and Gluon use: [pad_h, pad_w] or just pad. My default for most 2D convnets in PyTorch is to use pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2, which results in a 'same-ish' padding value. This can always be done on top of the full low/high padding sequence above.
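
The 'same-ish' formula can be checked with a few values. The helper names below are hypothetical; the output-size expression is the usual one for a convolution with symmetric padding along a single spatial dimension.

```python
def same_ish_pad(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Symmetric padding from the formula quoted above."""
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


def conv_out_size(size: int, kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Output length along one spatial dim when same_ish_pad is applied to both sides."""
    pad = same_ish_pad(kernel_size, stride, dilation)
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * pad - effective_kernel) // stride + 1


print(same_ish_pad(3))                 # 1
print(same_ish_pad(3, stride=2))       # 1
print(same_ish_pad(3, dilation=2))     # 2
print(conv_out_size(32, 3))            # 32
print(conv_out_size(32, 3, stride=2))  # 16
```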

    Some TF models explicitly work around the limitations of SAME padding. By limitations, I mean the fact that you end up with input dependent padding that can be asymmetric and shift your feature maps relative to each other in a manner that varies as you change your input size. https://github.com/tensorflow/models/blob/146a37c6663e4a249e02d3dff0087b576e3dc3a1/research/deeplab/core/xception.py#L81-L201
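
The input dependence is easy to see by computing the padding along one dimension. This sketch follows the commonly documented rule for TensorFlow's SAME mode (output size is ceil(input / stride), and any odd leftover goes to the end); the function name is hypothetical.

```python
import math


def tf_same_padding(size: int, kernel_size: int, stride: int = 1):
    """Return (pad_beg, pad_end) along one dim under the SAME rule."""
    out = math.ceil(size / stride)
    total = max((out - 1) * stride + kernel_size - size, 0)
    pad_beg = total // 2
    return pad_beg, total - pad_beg


print(tf_same_padding(224, 3, stride=2))  # (0, 1): asymmetric
print(tf_same_padding(225, 3, stride=2))  # (1, 1): symmetric
```

Changing the input size by a single pixel changes both the amount of padding and whether it is symmetric, which is the feature-map shift described above.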

    Possible interfaces:

    • padding : Union[ConvPadding, Sequence[Tuple[int, int]]] (like conv_general_dilated but with the enum for valid/same)

    • Add more modes to the enum, with associated values for those that need them, via a dataclass

    import enum
    from dataclasses import dataclass
    from typing import Optional, Sequence, Tuple, Union


    class PaddingType(enum.Enum):
        """An Enum holding the possible padding values for convolution modules."""
        SAME = 'SAME'
        VALID = 'VALID'
        RAW = 'RAW'  # specify padding as seq of high/low tuples
        SYM = 'SYM'  # specify symmetric padding for spatial dim as tuple for H, W or single int


    @dataclass
    class Padding:
        type: PaddingType = PaddingType.SAME
        # Only used by RAW and SYM; stays None for SAME and VALID.
        value: Optional[Union[Sequence[Tuple[int, int]], Tuple[int, int], int]] = None

        @classmethod
        def same(cls):
            return cls(PaddingType.SAME)

        @classmethod
        def valid(cls):
            return cls(PaddingType.VALID)

        @classmethod
        def raw(cls, value: Sequence[Tuple[int, int]]):
            return cls(PaddingType.RAW, value=value)

        @classmethod
        def sym(cls, value: Union[Tuple[int, int], int]):
            return cls(PaddingType.SYM, value=value)
    
    feature request 
    opened by rwightman 6
  • `objax.variable.VarCollection.update` not compliant with key-value assignment

    Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

    I'm trying to load some VarCollection and/or Dict[str, jnp.DeviceArray] params into the model.vars() which is a VarCollection class, and I can do so by:

    for key, value in new_params.items():
        model.vars()[key].assign(value)
    

    But I'd expect objax.variable.VarCollection.update to work the same way e.g.

    model.vars().update(new_params)
    

    The latter doesn't work while the former does. I'm not sure whether that's because it isn't the intended behavior for VarCollection.update or because I'm doing something wrong. The loop is fine for what I need at the moment, but I wanted to mention this in case something isn't working as expected.

    opened by alvarobartt 1
  • `objax.variable.VarCollection.update` fails when passing `Dict[str, Any]`

    Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

    I was playing around with objax for a bit and realized that updating model.vars() (a VarCollection) through VarCollection.update, which overrides the default dict.update, fails if you pass a plain Python dictionary instead of a VarCollection. The dictionary is cast to a Python list, and the method then loops over that list as if it were a dictionary's items, so it throws a ValueError: too many values to unpack (expected 2).

    https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311-L318

    Is this intended? Shouldn't VarCollection.update just loop over classes that allow .items()?
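
The failure mode described in this report can be reproduced without Objax. The `Collection` class below is a hypothetical, simplified stand-in that only mimics the cast-to-list step; it is not the library's actual update method, which does more than assign values.

```python
class Collection(dict):
    """Hypothetical stand-in for a dict subclass with a custom update()."""

    def update_buggy(self, other):
        if not isinstance(other, self.__class__):
            other = list(other)  # list(dict) keeps only the keys
        else:
            other = other.items()
        for k, v in other:  # tries to unpack each key string into two names
            self[k] = v

    def update_fixed(self, other):
        # Accept anything with .items(), otherwise an iterable of (key, value) pairs.
        pairs = other.items() if hasattr(other, "items") else other
        for k, v in pairs:
            self[k] = v


params = {"(Linear).w": 1.0, "(Linear).b": 0.0}

raised = False
try:
    Collection().update_buggy(params)
except ValueError as err:
    raised = True
    print(err)  # too many values to unpack (expected 2)

fixed = Collection()
fixed.update_fixed(params)
print(sorted(fixed))
```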

    opened by alvarobartt 0
  • Update nn.rst

    The channel number for 'in' is currently set as c which is incorrect because c is referring to the output channel number. Instead this needs to be set as t (which is the variable that iterates over the input channel numbers). in[n,c,i+h,j+w] should be changed to in[n,t,i+h,j+w]
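
The corrected indexing can be written out as a naive convolution. This sketch assumes VALID padding, stride 1, no bias, and a kernel laid out as [i][j][t][c] to match the index names used in this comment; it is an illustration, not the library's implementation.

```python
def conv2d_valid(inp, kernel):
    """Naive convolution. inp: [N][T][H][W] nested lists; kernel: [KH][KW][T][C]."""
    n_batch, n_in = len(inp), len(inp[0])
    height, width = len(inp[0][0]), len(inp[0][0][0])
    kh, kw, n_out = len(kernel), len(kernel[0]), len(kernel[0][0][0])
    out_h, out_w = height - kh + 1, width - kw + 1
    out = [[[[0.0] * out_w for _ in range(out_h)] for _ in range(n_out)]
           for _ in range(n_batch)]
    for n in range(n_batch):
        for c in range(n_out):  # c: output channel
            for h in range(out_h):
                for w in range(out_w):
                    out[n][c][h][w] = sum(
                        inp[n][t][i + h][j + w] * kernel[i][j][t][c]  # input uses t, not c
                        for t in range(n_in)  # t: input channel
                        for i in range(kh)
                        for j in range(kw)
                    )
    return out


# One sample, two input channels, one output channel, 1x1 kernel with weights 1 and 10.
inp = [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]
kernel = [[[[1.0], [10.0]]]]
result = conv2d_valid(inp, kernel)
print(result)  # [[[[51.0, 62.0], [73.0, 84.0]]]]
```

With two input channels and a single output channel, indexing the input by c would only ever read channel 0, which is why the formula needs t there.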

    opened by divyas248 1
  • pmean inside objax.parallel causes multithreading deadlock for more than 2 gpus

    Hi, I've noticed a problem, where I'd like to ask for your expertise. I'm not entirely sure if it is an objax problem or rather a Jax problem under the hood, but as it is triggered by objax commands I'll post it here.

    Description

    In particular, when combining objax.Parallel and objax.functional.pmean (as done in this tutorial) I encounter problems with more than 2 GPUs (with 2 GPUs it works fine). It results in a deadlock situation, where nothing happens anymore. If I understand the tutorial correctly, the pmean is necessary to average the gradients of all cards.

    Minimal reproducible example

    import objax
    import numpy as np
    from objax.zoo.resnet_v2 import ResNet18
    from jax import numpy as jnp, device_count
    from tqdm import tqdm
    
    
    if __name__ == "__main__":
        print(f"Num devices: {device_count()}")
        model = ResNet18(3, 1)
        opt = objax.optimizer.SGD(model.vars())
    
        @objax.Function.with_vars(model.vars())
        def loss(x, label):
            return objax.functional.loss.mean_squared_error(
                model(x, training=True), label
            ).mean()
    
        gv = objax.GradValues(loss, model.vars())
    
        train_vars = model.vars() + gv.vars() + opt.vars()
    
        @objax.Function.with_vars(train_vars)
        def train_op(
            image_batch,
            label_batch,
        ):
    
            grads, loss = gv(image_batch, label_batch)
            # grads = objax.functional.parallel.pmean(grads) # this line
            # loss = objax.functional.parallel.pmean(loss) # and this line
            loss = loss[0]
            opt(1e-3, grads)
            return loss, grads
    
        train_op = objax.Parallel(train_op, reduce=jnp.mean, vc=train_vars)
    
        with (train_vars).replicate():
            for _ in tqdm(range(10), total=10):
                data = jnp.array(np.random.randn(512, 3, 224, 224))
                label = jnp.zeros((512, 1))
                loss, grads = train_op(data, label)
    
    

    Whenever you uncomment the two lines with pmean, the program gets stuck. However, if I understood it correctly, they are necessary to get the average of the gradients over all cards.

    Error traces

    As with most deadlock bugs, you don't get an error stack trace. However, I have two clues that I've found so far. One is that if these lines are uncommented, the following appears:

    2022-08-22 14:55:46.462557: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
    2022-08-22 14:55:48.543291: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:36] Thread is unstuck! Warning above was a false-positive. Perhaps the timeout is too short.
    

    The other is that if I manually interrupt it with ctrl+c I got this lengthy stacktrace

    Setup

    We use 4 NVIDIA A40 GPUs with CUDA Version 11.7 (Driver Version 515.65.01), cudnn 8.2.1.32, jax version 0.3.15, objax version 1.6.0

    opened by a1302z 3
Releases(v1.6.0)
  • v1.6.0(Feb 1, 2022)

  • v1.4.0(Apr 1, 2021)

    • Added prototype of ducktyping of Objax variables as JAX arrays
    • Added prototype of automatic variable tracing
    • Added learning rate scheduler
    • Various bugfixes
  • v1.3.1(Feb 3, 2021)

  • v1.3.0(Jan 29, 2021)

    • Feature: Improved error messages overall
    • Feature: Improved BatchNorm numerical stability
    • Feature: Objax2Tf for serving objax using TensorFlow
    • Feature: New API objax.optimizer.ExponentialMovingAverageModule for easy moving average of a model
    • Feature: Automatic broadcasting of scalars for objax.Parallel
    • Feature: New optimizer: LARS
    • Feature: New API added to functional (lax.scan)
    • Feature: Modules can be printed to nicely readable text now (repr)
    • Feature: New interpolate API (for images)
    • Bugfix: make objax.Sequential work with latest JAX
  • v1.2.0(Nov 2, 2020)

    • Feature: Improved error messages.
    • Feature: Extended syntax: allow assigning TrainVar without TrainRef for direct experimentation.
    • Feature: Extended padding options for pad and convolution.
    • Feature: Modified ResNet_V2 to be Keras compatible.
    • Feature: Defaults can be overridden in call for Adam, Momentum.
    • BugFix: Layer norm initialization in GPT-2.
Owner
Google
Google ❤️ Open Source
(ICCV 2021) ProHMR - Probabilistic Modeling for Human Mesh Recovery

ProHMR - Probabilistic Modeling for Human Mesh Recovery Code repository for the paper: Probabilistic Modeling for Human Mesh Recovery Nikos Kolotouros

Nikos Kolotouros 209 Dec 13, 2022
Official implementation for “Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior”

HEP Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior Implementation Python3 PyTorch=1.0 NVIDIA GPU+CUDA Training process The

FengZhang 34 Dec 04, 2022
Open Source Light Field Toolbox for Super-Resolution

BasicLFSR BasicLFSR is an open-source and easy-to-use Light Field (LF) image Super-Resolution (SR) toolbox based on PyTorch, including a collection o

Squidward 50 Nov 18, 2022
"Neural Turing Machine" in Tensorflow

Neural Turing Machine in Tensorflow Tensorflow implementation of Neural Turing Machine. This implementation uses an LSTM controller. NTM models with m

Taehoon Kim 1k Dec 06, 2022
Codes for paper "Towards Diverse Paragraph Captioning for Untrimmed Videos". CVPR 2021

Towards Diverse Paragraph Captioning for Untrimmed Videos This repository contains PyTorch implementation of our paper Towards Diverse Paragraph Capti

Yuqing Song 61 Oct 11, 2022
This is the source code for yolox-keras, which can be used to train your own models.

YOLOX: Implementation of the You Only Look Once object detection model in Keras. Contents: Performance, Achievement, Environment, TricksSet, Download, How2train, Ho

Bubbliiiing 64 Nov 10, 2022
[ICRA 2022] An opensource framework for cooperative detection. Official implementation for OPV2V.

OpenCOOD OpenCOOD is an Open COOperative Detection framework for autonomous driving. It is also the official implementation of the ICRA 2022 paper OPV

Runsheng Xu 322 Dec 23, 2022
A Partition Filter Network for Joint Entity and Relation Extraction EMNLP 2021

EMNLP 2021 - A Partition Filter Network for Joint Entity and Relation Extraction

zhy 127 Jan 04, 2023
This repository contains Prior-RObust Bayesian Optimization (PROBO) as introduced in our paper "Accounting for Gaussian Process Imprecision in Bayesian Optimization"

Prior-RObust Bayesian Optimization (PROBO) Introduction, TOC This repository contains Prior-RObust Bayesian Optimization (PROBO) as introduced in our

Julian Rodemann 2 Mar 19, 2022
A framework that allows people to write their own Rocket League bots.

YOU PROBABLY SHOULDN'T PULL THIS REPO Bot Makers Read This! If you just want to make a bot, you don't need to be here. Instead, start with one of thes

543 Dec 20, 2022
Python library for science observations from the James Webb Space Telescope

JWST Calibration Pipeline JWST requires Python 3.7 or above and a C compiler for dependencies. Linux and MacOS platforms are tested and supported. Win

Space Telescope Science Institute 386 Dec 30, 2022
Object Database for Super Mario Galaxy 1/2.

Super Mario Galaxy Object Database Welcome to the public object database for Super Mario Galaxy and Super Mario Galaxy 2. Here, we document all object

Aurum 9 Dec 04, 2022
Video-based open-world segmentation

UVO_Challenge Team Alpes_runner Solutions This is an official repo for our UVO Challenge solutions for Image/Video-based open-world segmentation. Our

Yuming Du 84 Dec 22, 2022
PyTorch implementation of ICLR 2022 paper PiCO: Contrastive Label Disambiguation for Partial Label Learning

PiCO: Contrastive Label Disambiguation for Partial Label Learning This is a PyTorch implementation of ICLR 2022 Oral paper PiCO; also see our Project

王皓波 147 Jan 07, 2023
Does Pretraining for Summarization Require Knowledge Transfer?

Pretraining summarization models using a corpus of nonsense

Approximately Correct Machine Intelligence (ACMI) Lab 12 Dec 19, 2022
Recreate CenternetV2 based on MMDET.

Introduction This project is trying to Recreate CenternetV2 based on MMDET, which is proposed in paper Probabilistic two-stage detection. This project

25 Dec 09, 2022
Differentiable Simulation of Soft Multi-body Systems

Differentiable Simulation of Soft Multi-body Systems Yi-Ling Qiao, Junbang Liang, Vladlen Koltun, Ming C. Lin [Paper] [Code] Updates The C++ backend s

YilingQiao 26 Dec 23, 2022
QuakeLabeler is a Python package to create and manage your seismic training data, processes, and visualization in a single place — so you can focus on building the next big thing.

QuakeLabeler Quake Labeler was born from the need for seismologists and developers who are not AI specialists to easily, quickly, and independently bu

Hao Mai 15 Nov 04, 2022