Elegy
Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.
Main Features
- Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
- Flexible: Elegy provides a functional PyTorch Lightning-like low-level API that offers maximal flexibility when needed.
- Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
- Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and NumPy pytrees.
For more information take a look at the Documentation.
Installation
Install Elegy using pip:
pip install elegy
For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.
Quick Start: High-level API
Elegy's high-level API provides a very simple interface you can use by following these steps:
1. Define the architecture inside a Module. We will use Flax Linen for this example:
import flax.linen as nn
import jax
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:
import elegy, optax
model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
3. Train the model using the fit method:
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
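The fit call above assumes that X_train, y_train, X_test, and y_test already exist. A minimal sketch of stand-in data follows, assuming flat feature vectors (the MLP above applies its Dense layers directly to the input) and integer labels for 10 classes. The helper make_split, the 784 feature size, and the random values are illustrative assumptions, not part of Elegy:

```python
import numpy as np

# Hypothetical stand-in data (assumption): random values with MNIST-like sizes.
rng = np.random.default_rng(seed=0)


def make_split(n_samples, n_features=784, n_classes=10):
    """Return (features, labels) shaped for a 10-class classifier."""
    # Flat float features in [0, 1), one row per sample.
    features = rng.random((n_samples, n_features), dtype=np.float32)
    # Integer class ids, as expected by a sparse categorical loss.
    labels = rng.integers(0, n_classes, size=(n_samples,), dtype=np.int32)
    return features, labels


X_train, y_train = make_split(1024)
X_test, y_test = make_split(256)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
```

Real data should have the same layout: a 2D float array of inputs and a 1D integer array of labels.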
Quick Start: Low-level API
Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define test_step to implement a linear classifier in pure JAX:
1. Calculate our loss, logs, and states:
import jax
import jax.numpy as jnp
import numpy as np
class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x, # inputs
        y_true, # labels
        states: elegy.States, # model state
        initializing: bool, # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))
2. Instantiate our LinearClassifier with an optimizer:
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
3. Train the model using the fit method:
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
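The comment in test_step says that parameters are requested by name. One way such name-based dependency injection can work is to inspect the function signature and pass only the arguments it declares. The sketch below is a generic illustration using the standard library; call_with_requested, example_step, and the sample values are hypothetical and are not Elegy's actual implementation:

```python
import inspect


def call_with_requested(fn, available):
    """Call fn with only the entries of `available` that fn names as parameters."""
    requested = inspect.signature(fn).parameters
    missing = [name for name in requested if name not in available]
    if missing:
        raise TypeError(f"unknown parameters requested: {missing}")
    # Build keyword arguments from the names the function asked for.
    kwargs = {name: available[name] for name in requested}
    return fn(**kwargs)


def example_step(x, y_true, initializing):
    # Requests three of the available names and ignores the rest.
    return {"n_inputs": len(x), "n_labels": len(y_true), "initializing": initializing}


available = {
    "x": [[0.0, 1.0], [1.0, 0.0]],
    "y_true": [1, 0],
    "sample_weight": None,
    "class_weight": None,
    "initializing": True,
}

result = call_with_requested(example_step, available)
print(result)
```

With this pattern a step function can declare only the names it needs, which matches how test_step above omits sample_weight and class_weight.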
Using Jax Frameworks
It is straightforward to integrate other functional JAX libraries with this low-level API:
class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        net_states, net_params = variables.pop("params")
        
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
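The loss and accuracy lines used in both versions of test_step can be checked in isolation on a tiny batch. The sketch below wraps the same expressions in a hypothetical helper, loss_and_accuracy, and feeds it all-zero logits, for which the cross-entropy should equal log(10):

```python
import jax
import jax.numpy as jnp


def loss_and_accuracy(logits, y_true, num_classes):
    """Mean categorical cross-entropy and accuracy for integer labels."""
    labels = jax.nn.one_hot(y_true, num_classes)
    # Negative log-probability of the true class, averaged over the batch.
    loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
    # Fraction of samples whose highest logit matches the label.
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
    return loss, accuracy


# Uniform logits: every class gets probability 1/10.
logits = jnp.zeros((2, 10))
y_true = jnp.array([0, 1])

loss, accuracy = loss_and_accuracy(logits, y_true, 10)
print(float(loss), float(accuracy))
```

Since argmax returns index 0 for tied logits, only the first sample counts as correct here, giving an accuracy of 0.5.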
More Info
- Getting Started: High-level API tutorial.
- Getting Started: Low-level API tutorial.
- Elegy's Documentation.
- The examples directory.
- What is Jax?
Examples
To run the examples first install some required packages:
pip install -r examples/requirements.txt
Now run the example:
python examples/flax_mnist_vae.py 
Contributing
Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.
About Us
We are some friends passionate about ML.
License
Apache
Citing Elegy
To cite this project:
BibTeX
@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.5.0},
year = {2020},
}
The current version can be retrieved from either the Release tag or the file elegy/__init__.py, and the year corresponds to the project's release year.
 
 

