Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Overview

Elegy


Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface that you can use by following these steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:

import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
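The fit call above assumes that X_train, y_train, X_test, and y_test already exist. As a minimal sketch, the block below builds synthetic stand-ins with MNIST-like shapes (28x28 uint8 images, integer labels from 0 to 9). The shapes, the sample counts, and the helper name make_fake_mnist are assumptions for illustration only; they are not part of Elegy.

```python
import numpy as np


def make_fake_mnist(n_samples, seed=0):
    """Return random (images, labels) shaped like MNIST; for illustration only."""
    rng = np.random.default_rng(seed)
    # uint8 pixel values in [0, 255], one 28x28 image per sample
    images = rng.integers(0, 256, size=(n_samples, 28, 28), dtype=np.uint8)
    # integer class ids in [0, 9], as expected by sparse categorical losses
    labels = rng.integers(0, 10, size=(n_samples,), dtype=np.int32)
    return images, labels


X_train, y_train = make_fake_mnist(1024)
X_test, y_test = make_fake_mnist(256, seed=1)
```

The labels are plain integer class ids rather than one-hot vectors, which matches the sparse categorical loss and metric used in step 2.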

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

import jax
import jax.numpy as jnp
import numpy as np

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x, # inputs
        y_true, # labels
        states: elegy.States, # model state
        initializing: bool, # if True we should initialize our parameters
    ):  
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))

2. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
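To make the loss arithmetic in test_step concrete, here is a NumPy-only sketch of the same one-hot cross-entropy and accuracy computation. It is a worked example, not Elegy code; the helper names log_softmax and crossentropy_and_accuracy are introduced here for illustration.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the row-wise max first so the exponentials cannot overflow
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def crossentropy_and_accuracy(logits, y_true, num_classes=10):
    # One-hot encode integer labels, mirroring jax.nn.one_hot(y_true, 10)
    labels = np.eye(num_classes)[y_true]
    # Mean over the batch of -sum(labels * log_softmax(logits))
    loss = np.mean(-np.sum(labels * log_softmax(logits), axis=-1))
    accuracy = np.mean(np.argmax(logits, axis=-1) == y_true)
    return loss, accuracy


# All-zero logits give a uniform prediction over the 10 classes
logits = np.zeros((4, 10))
y_true = np.array([0, 1, 2, 3])
loss, accuracy = crossentropy_and_accuracy(logits, y_true)
```

With uniform predictions the loss equals log(10), and because argmax breaks ties by picking class 0, only the first of the four samples counts as correct.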

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API:

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        net_states, net_params = variables.pop("params")
        
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)

More Info

Examples

To run the examples first install some required packages:

pip install -r examples/requirements.txt

Now run the example:

python examples/flax_mnist_vae.py 

Contributing

Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are some friends passionate about ML.

License

Apache

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.5.0},
year = {2020},
}

Where the current version may be retrieved either from the Release tag or the file elegy/__init__.py and the year corresponds to the project's release year.

Comments
  • Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks.

    As noted below, this PR contains the following features:

    • It turns Elegy into a framework-agnostic library by removing the dependencies between elegy.Model and elegy.Module. It proposes the GeneralizedModule API and implements it for Flax, Haiku, and Elegy Module types, as well as regular Python functions.
    • It introduces a new low-level API similar to Pytorch Lightning that lets users manually override the core parts of the training loop when maximal flexibility is required.
    • General changes that enable the framework-agnostic mindset.
    • Many quality of life changes like standardization of hooks, simplification of the Module system, etc.

    Tasks:

    • [x] Create hooks module
    • [x] Refactor Model with low-level API and remove Module dependencies
    • [x] Refactor Module to use hooks
    • [x] Create GeneralizedModule and GeneralizedOptimizer Interfaces
    • [x] Implement GeneralizedModule for flax.linen.Module
    • [x] Implement GeneralizedModule for elegy.Module
    • [x] Implement GeneralizedModule for haiku.Module
    • [x] Implement GeneralizedOptimizer for optax.GradientTransformation
    • [x] Implement GeneralizedOptimizer for elegy.Optimizer
    • [x] Fix Model.summary
    • [x] Fix tests
    • [x] Fix examples
    • [ ] Fix README
    • [ ] Fix guides
    • [ ] Fix docstrings
    opened by cgarciae 27
  • WGAN-GP low-level API example

    A more extensive example using the new low-level API: Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.

    Some good generated images: epoch-0079 epoch-0084 epoch-0089

    Some notes:

    • I first tried to train a DCGAN which uses binary crossentropy but I've run into balancing issues. The discriminator quickly becomes too good so that the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore, most use BCE.
    • I'm still in favor of making Module.apply() return init(). It's just too much boilerplate to use an if-else every time. I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation which I think is also not nice.
    • Can we make Module.apply() accept params and states separately instead of collections? It's annoying having to construct a dict {'params':params, 'states':states} every time.
    • It would be nice if elegy.States was a dict so that users can decide for themselves what to put into it. With GANs, where you have to manage generator and discriminator states separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.
    • Model.save() fails on this model. Partially due to the extra jitted functions but even when I remove them, cloudpickle chokes on _HooksContext

    @cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

    opened by alexander-g 11
  • Add learning rate logging

    Implements the same functionality from #131 using only minor modifications to elegy.Optimizer.

    • [x] Add lr_schedule and steps_per_epoch to Optimizer.
    • [x] Implement Optimizer.get_effective_learning_rate
    • [x] Copy logging code from #131
    • [x] Add documentation

    @alexander-g Here is a proposal that is a bit simpler, closer to what I mentioned in #124. What do you think? @charlielito should we log the learning rate automatically if available or should we create a Callback?

    opened by cgarciae 9
  • Question: how to set the random state when calling model.predict(...)

    Not sure if this is the right place to post this...

    I have built and trained a VAE. When calling model.predict(x=test_set), I would like to make multiple predictions for each item in the test set (because VAEs are probabilistic). That way I can look at the distribution of predictions for each item in the test set.

    The call() for the VAE includes the line
    intrinsic_latents = mean + stds * jax.random.normal(self.next_key(), mean.shape).

    I haven't been able to find an explanation for how self.next_key() works or how to change the random seed on each call so that I can get different predictions. I could rewrite the code so that random seeds are explicitly passed, but I assume there is some functionality built into Elegy to make this easy?

    Could someone explain how this works, or point me to the documentation explaining it?

    Thanks!

    opened by jfcrenshaw 8
  • Examples Cleanup

    • refactored examples/imagenet/resnet_imagenet.py to accept parameters instead of modifying them inside the script
    • added README.md for examples/imagenet/
    • removed unnecessary Lambda class from examples/mnist.py
    • moved global average pooling in examples/mnist_conv.py before the Linear layer
    opened by alexander-g 7
  • Resnet

    • ResNet model architecture and an example for training on ImageNet
      • code is mostly adapted from the flax library
      • pretrained ResNet50 with 76.5% accuracy
      • pretrained ResNet18 with 68.7% accuracy
    • Experimental support for mixed precision: previously, all layers set their parameters' dtype to the input's dtype. This is incorrect: for numerical stability, all parameters should be float32 even when performing float16 computations. See more here.
    • Some issues I had during training:
      • There seems to be a memory leak during training, RAM constantly increased
      • I had to use smaller batch sizes than when training with flax or with TensorFlow before maxing out GPU memory (64 instead of 128 for ResNet50 on a RTX2080Ti). This might be of course due to a mistake in my code, but the number of parameters is identical to the flax and PyTorch versions, so I think it might be somewhere else
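The mixed-precision point above can be checked with a small NumPy experiment. This is a sketch of the general idea (float32 master weights, float16 compute), not the code from this PR; the variable names and values are chosen for illustration.

```python
import numpy as np

# A small gradient update, typical in magnitude for late training
update = 1e-4

# Stored in float16, the update is below the spacing between representable
# values near 1.0 and is rounded away entirely.
w16_new = np.float16(1.0) + np.float16(update)

# Stored in float32, the same update is preserved.
w32_new = np.float32(1.0) + np.float32(update)

# Mixed precision: keep float32 master weights, cast only for the computation.
x = np.ones((2, 3), dtype=np.float16)
w_master = np.full((3, 4), 0.5, dtype=np.float32)
y = x @ w_master.astype(np.float16)
```

Here w16_new is still exactly 1.0 while w32_new has moved, which is why the parameters are kept in float32 even though the matrix product runs in float16.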
    opened by alexander-g 7
  • [Bug] Problem with computing metrics

    Describe the bug Hi, when I use the fit function I get an error saying that the update function was not given y_true and y_pred. It seems to come from the model's metrics, because the error disappears if I comment out the metrics line.

    TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'
    

    Minimal code to reproduce Small snippet that contains a minimal amount of code.

    import jax
    import jax.numpy as jnp
    import ml_collections
    import numpy as np
    import optax
    import elegy as eg
    
    
    class eCNN(eg.Module):
        """A simple CNN model."""
    
        @eg.compact
        def __call__(self, x):
            x=eg.Conv(10,kernel_size=(10,))(x)
            x=jax.nn.relu(x)
            x = eg.Linear(1)(x)
            x=jax.nn.sigmoid(x)
            return x
    
    n=200
    X_train = np.random.rand(n*100).reshape(n,100)
    y_train = np.random.rand(n).reshape(n,1)
    print(X_train.shape)
    print(y_train.shape)
    
    model = eg.Model(
        module=eCNN(),
        loss=[
            eg.losses.MeanSquaredError(),
        ],
        metrics=eg.metrics.MeanSquareError(),  #Line to be commented to get rid of the error
        optimizer=optax.rmsprop(1e-3),
    )
    
    model.fit(X_train,y_train,
        epochs=10,
        batch_size=20,
        #validation_data=0.1,
        shuffle=False,
        callbacks=[eg.callbacks.TensorBoard("summaries")]
        )
    

    Library Info Please provide os info and elegy version.

    import elegy
    print(elegy.__version__) 
    # 0.8.4
    
    bug 
    opened by organic-chemistry 6
  • Multi-gpu with pmap docs

    One of the selling points of JAX is the pmap transformation, but best practices for actually making your training loop parallelizable are still confusing. What is Elegy's story around multi-GPU training? Is it possible to get a pytorch-lightning-like API, as a single argument to model.fit?

    opened by sooheon 6
  • SCCE fix for bug in Jax<0.2.7

    Small fix for a bug in Jax<0.2.7 where jax.lax.take_along_axis gives incorrect results for uint8 indices. Very relevant for semantic segmentation.

    Alternatively consider updating Jax
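For context, sparse categorical cross-entropy picks the log-probability of the true class with a take_along_axis gather, which is where uint8 label maps (common in semantic segmentation) come into play. The NumPy sketch below casts the labels to int32 before the gather. That cast is one plausible workaround and is an assumption; the PR's actual change is not shown in this text, and the function names are illustrative.

```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sparse_categorical_crossentropy(logits, labels):
    """Per-element loss; logits have shape (..., C), labels have shape (...)."""
    log_probs = log_softmax(logits)
    # Cast narrow integer labels (e.g. uint8 masks) to int32 before gathering
    idx = labels.astype(np.int32)[..., None]
    picked = np.take_along_axis(log_probs, idx, axis=-1)
    return -picked[..., 0]


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))  # e.g. a tiny 2x3 image with 5 classes
labels = rng.integers(0, 5, size=(2, 3)).astype(np.uint8)
loss = sparse_categorical_crossentropy(logits, labels)
```

The result has one loss value per pixel and agrees with the dense formulation that multiplies one-hot labels by the log-softmax.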

    opened by alexander-g 6
  • Dataset & DataLoader

    Dataset and parallel DataLoader API similar to PyTorch. Can be used with Model.fit()

    class MyDataset(elegy.data.Dataset):
        def __len__(self):
            return 128
    
        def __getitem__(self, i):
            #dummy data
            return np.random.random([224, 224, 3]),  np.random.randint(10)
    
    ds     = MyDataset()
    loader = elegy.data.DataLoader(ds, batch_size=8, n_workers=8, worker_type='thread', shuffle=True)
    
    batch = next(iter(loader))
    assert batch[0].shape == (8,224,224,3)
    assert batch[1].shape == (8,)
    assert len(loader) == 16
    
    model.fit(loader, epochs=10)
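The assertions above follow from simple batching arithmetic: 128 samples in batches of 8 give 16 batches, and stacking 8 samples adds a leading batch axis. The sketch below reproduces that logic sequentially in plain NumPy. It is not the elegy.data.DataLoader implementation (no workers, no shuffling), and ToyDataset and iterate_batches are hypothetical names.

```python
import math

import numpy as np


class ToyDataset:
    """Map-style dataset: only __len__ and __getitem__ are required."""

    def __len__(self):
        return 128

    def __getitem__(self, i):
        rng = np.random.default_rng(i)  # deterministic dummy data per index
        return rng.random((224, 224, 3)), int(rng.integers(10))


def iterate_batches(ds, batch_size):
    # Walk the index range in chunks and stack each chunk into arrays
    for start in range(0, len(ds), batch_size):
        stop = min(start + batch_size, len(ds))
        xs, ys = zip(*(ds[i] for i in range(start, stop)))
        yield np.stack(xs), np.asarray(ys)


ds = ToyDataset()
num_batches = math.ceil(len(ds) / 8)
first_x, first_y = next(iter(iterate_batches(ds, batch_size=8)))
```

A parallel loader performs the same per-index fetches, but spreads them across workers before stacking.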
    
    opened by alexander-g 6
  • Implemented BinaryCrossentropy metric

    Updates:

    • Created BinaryCrossentropy metric
    • Created basic tests for BinaryCrossentropy metric (passing)
    • Created docs for BinaryCrossentropy metric
    • Refactored main docs by balancing files and correcting language typos
    documentation 
    opened by sebasarango1180 6
  • use poetry-core

    poetry-core is intended to be a light weight, fully compliant, self-contained package allowing PEP 517 compatible build frontends to build Poetry managed projects.

    Using poetry-core allows distribution packages to depend only on the build backend.

    opened by dotlambda 0
  • Documentation/API reference not accessible via project website[Bug]

    Hi, It looks like I can't really access the API reference for Elegy. The corresponding link on the project's website simply takes me back to the main page (https://poets-ai.github.io/elegy/). Any idea what's up?

    bug 
    opened by geomlyd 0
  • [Bug] elegy does not work with latest haiku version

    Describe the bug When I type 'import elegy' I get this error

     File "/home/kpmurphy/mambaforge/lib/python3.10/site-packages/elegy/generalized_module/haiku_module.py", line 4, in <module>
        from haiku._src.base import current_bundle_name
    

    Minimal code to reproduce

    import elegy
    

    Library Info Please provide os info and elegy version.

    >> 
    >>> jax.__version__
    '0.2.28'
    >>> haiku.__version__
    '0.0.9.dev'
    >>> elegy.__version__. #  elegy-0.5.0-py3-none-any.whl 
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NameError: name 'elegy' is not defined
    >>> 
    

    bug 
    opened by murphyk 5
  • CSVLogger iteration over a 0-d array

    Describe the bug When using the CSVLogger callback, elegy crashes at the end of the first epoch.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    x = np.random.randn(64, 1)
    y = np.random.randn(64, 1)
    
    model = eg.Model(
        eg.nn.Linear(1),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        x,
        y,
        epochs=10,
        callbacks=[
            eg.callbacks.CSVLogger("train.csv"),  # <-- commenting this line out avoids the crash
        ]
    )
    

    Stack trace:

    Epoch 1/10
    2/2 [==============================] - ETA: 0s - loss: 1.3408 - mean_squared_error_loss: 1.3408
    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/csv.py", line 14, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 465, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/callback_list.py", line 221, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in on_epoch_end
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in <genexpr>
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 68, in handle_value
        return '"[%s]"' % (", ".join(map(str, k)))
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py", line 245, in __iter__
        raise TypeError("iteration over a 0-d array")  # same as numpy error
    TypeError: iteration over a 0-d array
    

    Expected behavior Should treat 0-d array as scalar.

    Library Info Please provide os info and elegy version. python version: 3.8.13 elegy version: 0.8.6 treex version: 0.6.10

    Additional context More detailed error information shows that the error occurs because the value is a JAX DeviceArray rather than a NumPy array, so the zero-dimensional check on the following line evaluates to False:

    is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
    
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py:6 │
    │ 8 in handle_value                                                                                │
    │                                                                                                  │
    │    65 │   │   │   if isinstance(k, six.string_types):                                            │
    │    66 │   │   │   │   return k                                                                   │
    │    67 │   │   │   elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:                   │
    │ ❱  68 │   │   │   │   return '"[%s]"' % (", ".join(map(str, k)))                                 │
    │    69 │   │   │   else:                                                                          │
    │    70 │   │   │   │   return k                                                                   │
    │    71                                                                                            │
    │                                                                                                  │
    │ ╭──────────────────────────── locals ─────────────────────────────╮                              │
    │ │ is_zero_dim_ndarray = False                                     │                              │
    │ │                   k = DeviceArray(4.8264385e-05, dtype=float32) │                              │
    │ ╰─────────────────────────────────────────────────────────────────╯                              │
    │                                                                                                  │
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py:245 in │
    │ __iter__                                                                                         │
    │                                                                                                  │
    │   242                                                                                            │
    │   243   def __iter__(self):                                                                      │
    │   244 │   if self.ndim == 0:                                                                     │
    │ ❱ 245 │     raise TypeError("iteration over a 0-d array")  # same as numpy error                 │
    │   246 │   else:                                                                                  │
    │   247 │     return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())            │
    │   248                                                                                            │
    │                                                                                                  │
    │ ╭───────────────────── locals ─────────────────────╮                                             │
    │ │ self = DeviceArray(4.8264385e-05, dtype=float32) │                                             │
    │ ╰──────────────────────────────────────────────────╯                                             │
    ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
    TypeError: iteration over a 0-d array
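One way to get the expected behavior is to detect zero-dimensional arrays by their ndim attribute instead of by type, so that NumPy arrays and JAX arrays are treated alike. The sketch below is a suggestion, not the fix adopted by Elegy. It keeps the handle_value name from the traceback, replaces six.string_types with str, and uses a hypothetical FakeDeviceArray class to stand in for a JAX scalar so it runs without JAX.

```python
from collections.abc import Iterable

import numpy as np


def handle_value(k):
    # Duck-typed check: anything exposing ndim == 0 is treated as a scalar
    is_zero_dim_array = getattr(k, "ndim", None) == 0
    if isinstance(k, str):
        return k
    elif isinstance(k, Iterable) and not is_zero_dim_array:
        return '"[%s]"' % (", ".join(map(str, k)))
    else:
        return k


class FakeDeviceArray:
    """Hypothetical stand-in for a 0-d JAX array: iterable type, ndim == 0."""

    ndim = 0

    def __iter__(self):
        raise TypeError("iteration over a 0-d array")


scalar_like = FakeDeviceArray()
result_scalar = handle_value(scalar_like)   # returned unchanged, no iteration
result_numpy = handle_value(np.array(0.5))  # 0-d NumPy array, also unchanged
result_list = handle_value([1, 2, 3])
```

Because the check no longer depends on isinstance(k, np.ndarray), the branch that iterates over k is skipped for any zero-dimensional array type.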
    
    bug 
    opened by ScottAlexanderCameron 0
  • Metrics ignore "on" keyword arg

    Describe the bug I have an application where I need to output multiple values from a network, which I am doing using a dictionary and using the on keyword argument. This works fine for the loss functions but not for metrics.

    Minimal code to reproduce Small snippet that contains a minimal amount of code.

    import elegy as eg
    import optax
    import numpy as np
    
    
    def data_generator():
        while True:
            yield (
                np.random.randn(10, 1),
                {"target": {"y": np.random.randn(10, 1)}},
            )
    
    
    class MyModule(eg.Module):
        @eg.compact
        def __call__(self, x):
            return {"y": eg.nn.Linear(1)(x)}
    
    
    model = eg.Model(
        MyModule(),
        loss=eg.losses.MeanSquaredError(on="y"),
        metrics=eg.metrics.MeanAbsoluteError(on="y"),  #  <-- works fine without this line
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        data_generator(),
        steps_per_epoch=10,
        epochs=10,
    )
    

    Stack trace:

    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
        tmp_logs = self.train_on_batch(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
        logs, model = train_step_fn(self, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
        return model.train_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
        grads, (logs, model) = grad_fn(params, model, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
        loss, logs, model = model.test_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
        batch_loss_and_logs.update(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
        self.metrics.update(**metrics_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
        metric.update(**metric_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
        values = _mean_absolute_error(preds, target)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
        target = target.astype(preds.dtype)
    AttributeError: 'dict' object has no attribute 'astype'
    

    Expected behavior Should produce the same result as if the dictionaries are removed and the on arg not specified.

    Library Info Please provide os info and elegy version. python version: 3.8.13 elegy version: 0.8.6 treex version: 0.6.10

    Additional context From my digging the cause seems to be due to the Metric.update() method being called instead of the __call__ method.
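The expected behavior can be read as follows: on selects the same sub-entry from both the predictions and the targets before the metric is computed. The sketch below shows that selection step with hypothetical helpers select_on and mean_absolute_error. Neither is the Elegy or Treex implementation; they only illustrate the semantics the report expects.

```python
import numpy as np


def select_on(value, on):
    """Index into a nested dict/list using a single key or a sequence of keys."""
    keys = [on] if isinstance(on, (str, int)) else list(on)
    for key in keys:
        value = value[key]
    return value


def mean_absolute_error(preds, target, on=None):
    # Apply the same selection to both sides before computing the metric
    if on is not None:
        preds = select_on(preds, on)
        target = select_on(target, on)
    return float(np.mean(np.abs(preds - target)))


preds = {"y": np.array([[1.0], [3.0]])}
target = {"y": np.array([[2.0], [1.0]])}
mae_with_on = mean_absolute_error(preds, target, on="y")
mae_plain = mean_absolute_error(preds["y"], target["y"])
```

Both calls give the same value, which is the equivalence described under "Expected behavior".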

    bug 
    opened by ScottAlexanderCameron 0
  • [Bug] Elegy crash on GPU

    Describe the bug

    Running mnist_cnn.py from the examples directory crashes the instance at the end of the first epoch.

    This was previously reported on Colab GPU instance. But I can reproduce this on CLI too.

    Running on CPU does not have this problem.

    Running on eager mode with GPU does not have this problem.

    Minimal code to reproduce

    python mnist_cnn.py
    

    Expected behavior Not stuck.

    Library Info CentOS Linux release 7.6.1810 elegy 0.8.6

    Additional context absl-py==1.2.0 aiohttp==3.8.1 aiosignal==1.2.0 async-timeout==4.0.2 attrs==22.1.0 certifi==2021.10.8 charset-normalizer==2.1.1 chex==0.1.4 click==8.1.3 cloudpickle==1.6.0 colorama==0.4.5 commonmark==0.9.1 cycler==0.11.0 datasets==2.4.0 dill==0.3.5.1 dm-tree==0.1.7 docker-pycreds==0.4.0 einops==0.4.1 elegy==0.8.6 etils==0.7.1 filelock==3.8.0 flax==0.4.2 fonttools==4.36.0 frozenlist==1.3.1 fsspec==2022.7.1 gitdb==4.0.9 GitPython==3.1.27 h5py==3.6.0 huggingface-hub==0.8.1 idna==3.3 importlib-resources==5.9.0 jax==0.3.16 jaxlib==0.3.15+cuda11.cudnn82 kiwisolver==1.4.4 matplotlib==3.5.3 msgpack==1.0.4 multidict==6.0.2 multiprocess==0.70.13 numpy==1.22.3 opt-einsum==3.3.0 optax==0.1.3 packaging==21.3 pandas==1.4.3 pathtools==0.1.2 Pillow==9.2.0 promise==2.3 protobuf==3.20.1 psutil==5.9.1 pyarrow==9.0.0 Pygments==2.13.0 pyparsing==3.0.9 python-dateutil==2.8.2 pytz==2022.2.1 PyYAML==6.0 requests==2.28.1 responses==0.18.0 rich==11.2.0 scipy==1.8.0 sentry-sdk==1.9.5 setproctitle==1.3.2 shortuuid==1.0.9 six==1.16.0 smmap==5.0.0 tensorboardX==2.5.1 toolz==0.12.0 tqdm==4.64.0 treeo==0.0.10 treex==0.6.10 typing_extensions==4.3.0 urllib3==1.26.11 wandb==0.12.21 xxhash==3.0.0 yarl==1.8.1 zipp==3.8.1

    bug 
    opened by jiyuuchc 2
Releases(0.8.6)
  • 0.8.6(Mar 23, 2022)

    🚀 Features

    • Weights and Biases Callback for Elegy
      • PR: #220

    🐛 Fixes

    • Docs typos
      • PR: #222
    • Donate model's memory buffer to jit/pmap functions.
      • PR: #226
    • Lazy load wandb
      • PR: #228
  • 0.8.5(Feb 23, 2022)

  • 0.8.4(Dec 14, 2021)

  • 0.8.3(Dec 13, 2021)

  • 0.8.2(Dec 13, 2021)

  • 0.8.1(Nov 8, 2021)

    Elegy is now based on Treex 🎉

    Changes

    • Removes the module, nn, metrics, and losses from Elegy; Elegy now re-exports these modules from Treex.
    • GeneralizedModule and friends are gone, to use Flax Modules use the elegy.nn.FlaxModule wrapper.
    • Low level API is massively simplified:
      • States is removed, since Model is a pytree all parameters are tracked automatically thanks to Treex / Treeo.
      • All static state arguments (training, initializing) are removed, Modules can simply use self.training to pick their training state and self.initializing() to check whether they are initializing.
      • Signature for pred_step, test_step, and train_step now simply consists of inputs and labels, where labels is a dict that can contain additional keys like sample_weight or class_weight as required by the losses and metrics.
    • Adds the DistributedStrategy class which currently has 3 instances
      • Eager: Runs model in a single device in eager mode (no jit)
      • JIT: Runs model in a single device with jit
      • DataParallel: Run the model in multiple devices using pmap.
    • Adds methods to change the model's distributed strategy:
      • .distributed(strategy = DataParallel): changes the distributed strategy, DataParallel used by default.
      • .local(): changes the distributed strategy to JIT.
      • .eager(): changes the distributed strategy to Eager.
    • Removes the .eager field in favor of the .eager() method.
  • 0.7.4(Jun 1, 2021)

  • 0.7.2(Mar 10, 2021)

  • 0.7.1(Mar 1, 2021)

  • 0.7.0(Feb 22, 2021)

    Features

    • init is now called only once internally and must be called explicitly by the user under certain circumstances
    • summary now uses jax.eval_shape under the hood, so it is very fast: it does not allocate memory or perform any computation on the device.

    Merged pull requests:

    • Fix notebook #166 (cgarciae)
    • Single Initialization: Removes the current progressive initialization in favor of a single underlying call to init_step. #165 (cgarciae)
  • 0.6.0(Feb 14, 2021)

  • 0.5.0(Feb 8, 2021)

    This version simplifies parts of the low-level API in the spirit of what was introduced in 0.4.0, to provide a simpler and more homogeneous experience.

    Merged pull requests:

    • Improve States: uses __dict__ so States works with vars #159 (cgarciae)
    • Simplify API: Cleans-up some API details around Model and Module #158 (cgarciae)
  • 0.4.1(Feb 3, 2021)

  • 0.4.0(Feb 1, 2021)

    Implemented enhancements:

    • [Feature Request] Monitoring learning rates #124

    Merged pull requests:

  • 0.3.0(Dec 17, 2020)

    Implemented enhancements:

    • elegy.nn.Sequential docs not clear #107
    • [Feature Request] Community example repo. #98

    Fixed bugs:

    • [Bug] Accuracy from Model.evaluate() is inconsistent with manually computed accuracy #109
    • Exceptions in "Getting Started" colab notebook #104

    Closed issues:

    • l2_normalize #102
    • Need some help for contributing new losses. #93
    • Document Sum #62
    • Binary Accuracy Metric #58
    • Automate generation of API Reference folder structure #19
    • Implement Model.summary #3

    Merged pull requests:

  • 0.2.2(Aug 31, 2020)

  • 0.2.1(Aug 25, 2020)

  • 0.2.0(Aug 17, 2020)

  • 0.1.5(Jul 28, 2020)

    • Mean Absolute Percentage Error Implementation @Ciroye
    • Adds elegy.nn.Linear, elegy.nn.Conv2D, elegy.nn.Flatten, elegy.nn.Sequential @cgarciae
    • Add Elegy hooks @cgarciae
    • Improves Tensorboard support @Davidnet
    • Added coverage metrics to CI @charlielito
  • 0.1.4(Jul 24, 2020)

    • Adds elegy.metrics.BinaryCrossentropy @sebasarango1180
    • Adds elegy.nn.Dropout and elegy.nn.BatchNormalization @cgarciae
    • Improves documentation
    • Fixes a bug that caused an error when using is_training via dependency injection in Model.predict.
  • 0.1.3(Jul 23, 2020)
