JMP is a Mixed Precision library for JAX.

Overview

Mixed precision training in JAX

Installation | Examples | Policies | Loss scaling | Citing JMP | References

Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.

This library implements support for mixed precision training in JAX by providing two key abstractions: mixed precision "policies" and loss scaling. Neural network libraries (such as Haiku) can integrate with jmp to provide "Automatic Mixed Precision (AMP)" support, automating or simplifying the application of policies to modules.

All code examples below assume the following:

import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32
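
The half dtype above is hard-coded; as a small sketch of our own (not part of jmp), you could instead pick it from the active JAX backend so the same script works on GPU and TPU:

# Assumption: jax.default_backend() reports "tpu" on TPU hosts; everywhere
# else we fall back to float16.
half = jnp.bfloat16 if jax.default_backend() == "tpu" else jnp.float16
full = jnp.float32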

Installation

JMP is written in pure Python, but depends on C++ code via JAX and NumPy.

Because JAX installation is different depending on your CUDA version, JMP does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install JMP using pip:

$ pip install git+https://github.com/deepmind/jmp

Examples

You can find a fully worked JMP example in Haiku which shows how to use mixed f32/f16 precision to halve training time on GPU and mixed f32/bf16 to reduce training time on TPU by a third.
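
As a hedged sketch of what that integration looks like (the network and shapes below are our own assumptions, not the worked example itself), Haiku exposes hk.mixed_precision.set_policy for attaching a jmp.Policy to a module class:

import haiku as hk

policy = jmp.get_policy("params=float32,compute=float16,output=float32")
# All hk.nets.MLP instances now store params in f32 and compute in f16.
hk.mixed_precision.set_policy(hk.nets.MLP, policy)

def forward(x):
  return hk.nets.MLP([128, 128, 10])(x)

net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))
logits = net.apply(params, jnp.ones([1, 28 * 28]))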

Policies

A mixed precision policy encapsulates the dtype configuration used in a mixed precision experiment.

# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)

The policy object can be used to cast pytrees:

def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params
  y = x @ w + b
  return my_policy.cast_to_output(y)

w = jnp.ones([3, 3], dtype=my_policy.param_dtype)
b = jnp.zeros([3], dtype=my_policy.param_dtype)
x = jnp.ones([1, 3], dtype=full)
y = layer((w, b), x)
assert y.dtype == half

You can replace the output type of a given policy:

my_policy = my_policy.with_output_dtype(full)
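
Policies also expose cast_to_param, which is commonly used to cast gradients back to the parameter dtype before an optimizer update. A minimal sketch (the SGD step itself is our own assumption, not part of jmp):

def sgd_step(params, grads, lr=0.1):
  # Cast half-precision gradients back to the parameter dtype so the update
  # (and the parameters themselves) stay in full precision.
  grads = my_policy.cast_to_param(grads)
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)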

You can also define a policy via a string, which may be useful for specifying a policy as a command-line argument or as a hyperparameter to your experiment:

my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
float16 = jmp.get_policy("float16")  # Everything in f16.
half = jmp.get_policy("half")        # Everything in half (f16 or bf16).

Loss scaling

When training with reduced precision, consider whether gradients will need to be shifted into the representable range of the format that you are using. This is particularly important when training with float16 and less important for bfloat16. See the NVIDIA mixed precision user guide [1] for more details.

The easiest way to shift gradients is with loss scaling, which scales your loss and gradients by S and 1/S respectively.

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)

The appropriate value for S depends on your model, loss, batch size and potentially other factors. You can determine it by trial and error. As a rule of thumb you want the largest value of S that does not introduce overflow during backprop. NVIDIA [1] recommends computing statistics about the gradients of your model (in full precision) and picking S such that its product with the maximum norm of your gradients is below 65,504 (the largest finite value representable in float16).
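
One way to act on that recommendation, sketched below under the assumption that you can first run a few training steps in full precision (the helper name pick_loss_scale is ours, not part of the jmp API):

def pick_loss_scale(full_precision_grads, headroom=2.0):
  # Largest absolute gradient value observed while running in float32.
  leaves = jax.tree_util.tree_leaves(full_precision_grads)
  max_grad = jnp.max(jnp.stack([jnp.max(jnp.abs(g)) for g in leaves]))
  # Pick the biggest power-of-two S whose product with max_grad stays a
  # comfortable margin below the float16 maximum of 65,504.
  s = 2.0 ** jnp.floor(jnp.log2(65504.0 / (headroom * max_grad)))
  return jmp.StaticLossScale(s)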

We provide a dynamic loss scale, which adjusts the loss scale periodically during training to find the largest value for S that produces finite gradients. This is more convenient and robust compared with picking a static loss scale, but has a small performance impact (between 1 and 5%).

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaNs when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)

In general using a static loss scale should offer the best speed, but we have optimized dynamic loss scaling to make it competitive. We recommend you start with dynamic loss scaling and move to static loss scaling if performance is an issue.

Finally, we offer a no-op loss scale which you can use as a drop-in replacement. It does nothing (apart from implementing the jmp.LossScale API):

loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1
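
This makes it easy to toggle loss scaling from configuration. The helper below is our own sketch, not part of the jmp API:

def make_loss_scale(name: str) -> jmp.LossScale:
  # "dynamic" and "static" both start from 2 ** 15; anything else disables
  # loss scaling entirely.
  if name == "dynamic":
    return jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
  if name == "static":
    return jmp.StaticLossScale(2 ** 15)
  return jmp.NoOpLossScale()

loss_scale = make_loss_scale("dynamic")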

Citing JMP

This repository is part of the DeepMind JAX Ecosystem. To cite JMP, please use the DeepMind JAX Ecosystem citation.

References

[0] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 https://arxiv.org/abs/1710.03740.

[1] "Training With Mixed Precision :: NVIDIA Deep Learning Performance Documentation". Docs.Nvidia.Com, 2020, https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.

Comments
  • Questions around speedup

    Hi,

    Thanks for creating this amazing library!

    So, from what I understand, this is the minimum needed to benefit from the JMP speedup:

    policy = jmp.Policy(param_dtype=jnp.float32, compute_dtype=jnp.float16, output_dtype=jnp.float16)
    ...
    # creating a network, and creating a loss_fn using the network
    data = policy.cast_to_compute(data)
    params = policy.cast_to_compute(params)
    
    grads = jax.grad(loss_fn)(data, params)
    ...
    grads = policy.cast_to_param(grads)
    

    The loss scale is only needed if we experience NaN or inf values. Thus, we should see a difference in the time needed for a jax.grad operation when the data and params are in float16.

    Please correct me if I'm wrong or if I missed anything.

    Considering that, I have some questions:

    • Can we expect to see a 2x (or any) speedup if we time the jax.grad operation with (params, data) in float16 versus float32, or is the speedup only achieved at scale?
    • Can we expect to see a 2x (or any) speedup with small networks and small experiments? (say 2-layer nets and experiments that only need 2-3 minutes on a single GPU)
    • Can we expect to see a 2x (or any) speedup with all types of networks, or only with specific architectures? (say a 50-layer MLP, Transformers, etc.)
    • Is it necessary to apply the mixed precision policy to the network (and thus to use Haiku's hk.mixed_precision.set_policy), or is having the parameters and data in float16 sufficient to get a speedup even with a network we create ourselves?

    Thank you and have a nice day!

    opened by 1m1ne 3
  • Basic question about the use of my_policy

    Hi,

    Thanks for creating this amazing functionality!

    I have a basic question about the use of policy functions to set certain precision levels. As stated in the example of your README.md:

    def layer(params, x):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    I am new to JAX, but the first thing I learned is that JAX likes pure functions. Does the use of my_policy violate the pure functions paradigm?

    Should it become:

    def layer(params, x, my_policy):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    The function can be jitted by using partial.

    from functools import partial
    layer_compiled = jit(partial(layer, my_policy=my_policy))
    

    Fundamental question

    A more fundamental question (which I maybe need to ask at JAX) is: how should functions passed as input to other functions be handled in JAX?

    def func_a(x, w, func):
      ...
    
    func_a_compiled = jit(partial(func_a, func=func_b))
    

    In order to jit this, I came up with the partial solution above. I assume your use of my_policy is valid, as you probably have more experience with JAX. But that creates some magic, which is undesirable for my use case. Is the jit(partial()) solution valid or is there a better way to handle functions as input to functions?

    Have a nice day! J

    opened by JSchuurmans 2
  • jmp-0.0.2.tar.gz on PyPI doesn't contain requirements.txt

    When trying to build my own wheel from the tarball of jmp-0.0.2 that has been uploaded to PyPI, I get the following error:

    Collecting jmp==0.0.2
      Downloading jmp-0.0.2.tar.gz (13 kB)
        Running command python setup.py egg_info
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 55, in <module>
            install_requires=_parse_requirements('requirements.txt'),
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 32, in _parse_requirements
            with open(requirements_txt_path) as fp:
        FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    WARNING: Discarding https://files.pythonhosted.org/packages/7c/ba/a6bfcaeedca8551e2fb4054d1fd061a0dd97d26dd44002b3e92d13b51877/jmp-0.0.2.tar.gz#sha256=fdb5cec0d10aab4116c2770f24b2adf4f503fcfbb96ce8ef583e1879bdbf1b9b (from https://pypi.org/simple/jmp/). Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement jmp==0.0.2 (from versions: 0.0.1, 0.0.2)
    ERROR: No matching distribution found for jmp==0.0.2
    

    After downloading and extracting jmp-0.0.2.tar.gz from PyPI and looking at setup.py, I see that it contains lines that reference requirements.txt:

    install_requires=_parse_requirements('requirements.txt')
    

    However, that file is not part of the source distribution, so my build fails.

    Oliver

    P.S. Yes, I know there is a wheel for jmp-0.0.2 on PyPI, which I ended up using, but I'm using our Wheels_builder script that recursively builds wheels for dependencies on our systems. My point is that the source distribution of jmp 0.0.2 is incomplete, as it lacks the information contained in the requirements file.

    opened by ostueker 1
  • Casting numpy array

    I'll start by saying that I've just started with JAX, so I might be doing something wrong.

    When I run the following code:

    precision = jmp.get_policy('params=float32,compute=float16,output=float16')
    some_input = np.arange(15).reshape((5, 3))
    @jax.jit
    def some_function(some_input):
        some_input = precision.cast_to_compute(some_input)
        print(some_input.dtype)
        # prints float32
        return a_float16_compute_model(some_input)  # fails
    

    It seems that at least on the first (tracing) run, the numpy array doesn't get cast to float16, maybe because it is being treated as a tree, here

    https://github.com/deepmind/jmp/blob/4b94370b8de29b79d6f840b09d1990b91c1afddd/jmp/_src/policy.py#L26

    Surprisingly, if I run

    some_input = np.arange(15).reshape((5, 3)).astype(precision.compute_dtype)
    print(some_input.dtype)
    
    #prints float16
    

    the cast succeeds. I think that the expected behavior is that precision.cast_to_compute(some_input) should return a numpy array with compute_dtype, but I might be missing something.

    opened by yardenas 1
  • [JAX] Fix test failure due to upcoming change to JAX.

    An upcoming change to JAX adds a @jit decorator around a number of array operators, in this case division (/). A side effect of that change is that large Python integer constants that overflow an int32 or int64 type may produce an error. The workaround is either to explicitly cast the large constants to a specific type (e.g., np.float64), or in this case we can just do the math in question in classic NumPy since it is computing test expectations.

    cla: yes 
    opened by copybara-service[bot] 0
  • Bump version to `0.0.3.dev`.

    This is so that if users install from GitHub, we can see in bug reports that they are not using a stable version:

    $ pip install git+https://github.com/deepmind/jmp
    $ python -c 'import jmp ; print(jmp.__version__)'
    0.0.3.dev
    
    cla: yes 
    opened by copybara-service[bot] 0
  • Cut 0.0.2 release and create GitHub action to publish to PyPI.

    We are not using 0.0.1 because this version was already used by the previous owner of the PyPI package (https://pypi.org/project/jmp/0.0.1/). As such, our releases start at 0.0.2.

    cla: yes 
    opened by copybara-service[bot] 0
  • Replace jax.lax.select with jnp.where

    Thanks for the awesome work!

    This PR fixes an issue where jax.lax.select complains about dtypes not being equal when adjusting the DynamicLossScale.

    The exception's stack trace ends with:

    File ".../jmp/_src/loss_scale.py", line 147, in adjust
        loss_scale = jax.lax.select(
    TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
    

    My code looks similar to

    scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
    ...
    gradients, scale = gradient_fn(..., scale)
    gradients = scale.unscale(gradients)
    gradients_finite = jmp.all_finite(gradients)
    scale = scale.adjust(gradients_finite)  # This line throws the exception
    ...
    
    opened by nlsfnr 6
Releases (v0.0.2)
  • v0.0.2 (Apr 15, 2021)

    Initial release of JMP.

    Changelog:

    • Add jmp.Policy abstraction and jmp.get_policy(..) factory.
    • Add jmp.LossScale and three implementations thereof (noop, static and dynamic).
    • Add various utilities (jmp.all_finite) to support common tasks in mixed precision codebases.