
This is not an officially supported Google product.

Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name is a contraction of Object and JAX, a popular high-performance framework. Objax is designed by researchers for researchers, with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

This is the developer repository of Objax. It contains very little user documentation; for the full documentation, go to objax.readthedocs.io.

You can find READMEs in the subdirectories of this project.

User installation guide

You install Objax using pip as follows:

pip install --upgrade objax
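To confirm which versions pip installed without importing the packages, you can query their metadata. This is a minimal sketch that uses only the standard library's importlib.metadata (Python 3.8+); the helper name installed_versions is hypothetical and not part of Objax.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["objax", "jax", "jaxlib"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Reading metadata instead of importing avoids triggering JAX's backend initialization just to check a version number.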

Objax supports GPUs but assumes that you already have some version of CUDA installed. Here are the extra steps:

# Set CUDA_VERSION to match your installed CUDA version (e.g. 11.0)
CUDA_VERSION=11.0
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
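The pip command pins jaxlib to the version already installed and appends a CUDA tag built by deleting the dots from the CUDA version. This sketch reproduces that string construction in Python so the resulting requirement is easy to inspect; the function name jaxlib_cuda_spec is hypothetical, and the "+cudaXYZ" tag format is taken from the command itself. Wheel naming has changed in later JAX releases, so check the current JAX installation instructions before relying on it.

```python
def jaxlib_cuda_spec(jaxlib_version: str, cuda_version: str) -> str:
    """Build a pip requirement like the one the shell command produces.

    The dots are removed from the CUDA version, mirroring the sed call,
    so "11.0" becomes "110".
    """
    cuda_tag = cuda_version.replace(".", "")
    return f"jaxlib=={jaxlib_version}+cuda{cuda_tag}"
```

For example, with a hypothetical jaxlib version 0.1.57 and CUDA 11.0, jaxlib_cuda_spec("0.1.57", "11.0") returns "jaxlib==0.1.57+cuda110".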

Useful environment configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
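To stop JAX from preallocating GPU memory from within Python (handy in notebooks), set XLA_PYTHON_CLIENT_PREALLOCATE before JAX is imported for the first time. This is a minimal sketch, assuming you control the process before any import of jax or objax; the variable is documented by JAX, while the helper name disable_jax_preallocation is hypothetical.

```python
import os


def disable_jax_preallocation() -> None:
    """Make JAX allocate GPU memory on demand instead of preallocating it.

    XLA reads this variable when JAX initializes its GPU backend, so this
    must run before the first `import jax` or `import objax`.
    """
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


disable_jax_preallocation()
# Import JAX/Objax only after the environment is configured, e.g.:
# import objax
```

On-demand allocation lets several programs share one GPU, at the cost of possible memory fragmentation compared with a single large preallocated block.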

Testing your installation

You can test your installation by running the code below:

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)
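The shapes printed by this test follow simple rules. This pure-Python sketch reproduces them under stated assumptions: inputs use the NCHW layout suggested by the (100, 3, 32, 32) example, Linear maps the last axis from nin to nout, and Conv2D defaults to "SAME" padding with stride 1, which is consistent with the (100, 4, 32, 32) result. The helpers linear_output_shape and conv2d_output_shape are hypothetical, not Objax functions.

```python
import math
from typing import Tuple


def linear_output_shape(x_shape: Tuple[int, ...], nout: int) -> Tuple[int, ...]:
    """A dense layer maps the last axis from nin to nout and keeps the others."""
    return (*x_shape[:-1], nout)


def conv2d_output_shape(x_shape: Tuple[int, int, int, int], nout: int, k: int,
                        stride: int = 1, padding: str = "SAME") -> Tuple[int, int, int, int]:
    """Output shape of a 2D convolution on NCHW input with a square k x k kernel."""
    n, _, h, w = x_shape
    if padding == "SAME":
        # SAME pads the input so that out = ceil(in / stride).
        out_h, out_w = math.ceil(h / stride), math.ceil(w / stride)
    elif padding == "VALID":
        # VALID uses no padding, so out = ceil((in - k + 1) / stride).
        out_h = math.ceil((h - k + 1) / stride)
        out_w = math.ceil((w - k + 1) / stride)
    else:
        raise ValueError(f"unsupported padding: {padding!r}")
    return (n, nout, out_h, out_w)


print(linear_output_shape((100, 4), nout=5))               # (100, 5)
print(conv2d_output_shape((100, 3, 32, 32), nout=4, k=3))  # (100, 4, 32, 32)
print(conv2d_output_shape((100, 3, 32, 32), nout=4, k=3, padding="VALID"))  # (100, 4, 30, 30)
```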

If you get errors when running this code with CUDA, your CUDA or cuDNN installation most likely has issues.

Running code examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples

Citing Objax

To cite this repository:

@software{objax2020github,
  author = {{Objax Developers}},
  title = {{Objax}},
  url = {https://github.com/google/objax},
  version = {1.2.0},
  year = {2020},
}

Developer documentation

Here is information about development setup and a guide on adding new code.

Release notes

  • v1.6.0 (Feb 1, 2022)

  • v1.4.0 (Apr 1, 2021)

    • Added prototype of ducktyping of Objax variables as JAX arrays
    • Added prototype of automatic variable tracing
    • Added learning rate scheduler
    • Various bugfixes
  • v1.3.1 (Feb 3, 2021)

  • v1.3.0 (Jan 29, 2021)

    • Feature: Improved error messages overall
    • Feature: Improved BatchNorm numerical stability
    • Feature: Objax2Tf for serving objax using TensorFlow
    • Feature: New API objax.optimizer.ExponentialMovingAverageModule for easy moving average of a model
    • Feature: Automatic broadcasting of scalars for objax.Parallel
    • Feature: New optimizer: LARS
    • Feature: New API added to functional (lax.scan)
    • Feature: Modules can be printed to nicely readable text now (repr)
    • Feature: New interpolate API (for images)
    • Bugfix: make objax.Sequential work with latest JAX
  • v1.2.0 (Nov 2, 2020)

    • Feature: Improved error messages.
    • Feature: Extended syntax: allow assigning TrainVar without TrainRef for direct experimentation.
    • Feature: Extended padding options for pad and convolution.
    • Feature: Modified ResNet_V2 to be Keras compatible.
    • Feature: Defaults can be overridden in the call for Adam and Momentum.
    • BugFix: Layer norm initialization in GPT-2.
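The v1.3.0 notes mention ExponentialMovingAverageModule for keeping a moving average of a model. The underlying idea can be illustrated on a scalar in plain Python. This is a generic sketch of an exponential moving average with bias correction, not Objax's implementation; the function name and the 0.999 default momentum are assumptions.

```python
from typing import Iterable, List


def exponential_moving_average(values: Iterable[float], momentum: float = 0.999) -> List[float]:
    """Track ema <- momentum * ema + (1 - momentum) * value over a sequence.

    The average starts at 0, so early estimates are biased toward 0;
    dividing by (1 - momentum ** step) corrects that bias.
    """
    ema = 0.0
    debiased = []
    for step, value in enumerate(values, start=1):
        ema = momentum * ema + (1.0 - momentum) * value
        debiased.append(ema / (1.0 - momentum ** step))
    return debiased
```

Applied element-wise to every model weight after each training step, the same update yields a smoothed copy of the model that is often used for evaluation.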
