Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Overview

Diffrax

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small, tightly written library.

Installation

pip install diffrax

Requires Python >=3.7 and JAX >=0.2.27.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

Here, Dopri5 refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.
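
To make the "5(4)" in the name concrete, here is a minimal plain-Python sketch of one Dormand--Prince step for a scalar ODE. It is not Diffrax's implementation; the helper name dopri5_step is hypothetical, and the coefficients are the standard published Butcher tableau. Each step produces a fifth-order result plus an embedded fourth-order result, and their difference serves as the local error estimate that an adaptive step size controller would consume.

```python
import math

# Dormand-Prince 5(4) Butcher tableau (standard published coefficients).
C = [0, 1/5, 3/10, 4/5, 8/9, 1, 1]
A = [
    [],
    [1/5],
    [3/40, 9/40],
    [44/45, -56/15, 32/9],
    [19372/6561, -25360/2187, 64448/6561, -212/729],
    [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
    [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84],
]
B5 = [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0]  # 5th-order weights
B4 = [5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]  # 4th-order weights


def dopri5_step(f, t, y, h, args=None):
    """One step for a scalar ODE; returns (5th-order result, error estimate)."""
    k = []
    for c_i, a_i in zip(C, A):
        # Each stage uses the slopes computed by the earlier stages.
        y_stage = y + h * sum(a * k_j for a, k_j in zip(a_i, k))
        k.append(f(t + c_i * h, y_stage, args))
    y5 = y + h * sum(b * k_j for b, k_j in zip(B5, k))
    y4 = y + h * sum(b * k_j for b, k_j in zip(B4, k))
    return y5, abs(y5 - y4)


def f(t, y, args):
    return -y


# Same problem as the quick example (first component only): y' = -y, y(0) = 2.
t, y, h = 0.0, 2.0, 0.1
max_err_est = 0.0
for _ in range(10):
    y, err = dopri5_step(f, t, y, h)
    max_err_est = max(max_err_est, err)
    t += h

print(y, 2.0 * math.exp(-1.0), max_err_est)
```

The printed values compare the numerical result at t=1 with the exact solution 2*exp(-1), alongside the largest per-step error estimate.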

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)

Comments
  • [WIP] Delay differential equations

    @thibmonsel

    This is a quick WIP draft of how we might add support for delay diffeqs into Diffrax.

    The goal is to make the API follow:

    def vector_field(t, y, args, *, history):
        ...
    
    delays = [lambda t, y, args: 1.0,
              lambda t, y, args: max(y, 1)]
    
    diffeqsolve(ODETerm(vector_field), ..., delays=delays)
    

    There are several pieces that still need doing:

    • The nonlinear solve, with respect to the dense solution over each step. (E.g. as per Section 4.1 of the DelayDiffEq.jl paper)
    • Detecting discontinuities and stepping to them directly. (Section 4.2)
    • Possibly add special support for "nice" delays, that we might be able to handle more efficiently? E.g. as long as our minimal delay is larger than our step size then the nonlinear solve can be skipped.
    • Adding documentation.
    • Adding an example.
    • Probably now would be a good time to figure out how to add support for solving DAEs as well (e.g. see #62). Both involve a nonlinear solve, and both involve passing extra information to the user-provided vector field. It might be that we can make use of the same mechanisms for both. (And at the very least we should ensure that any choices we make now don't negatively impact DAE support later.)
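
As an illustration of the "nice" delay case mentioned in the list above, here is a plain-Python sketch, not the proposed Diffrax API. When a constant delay is at least one step long, the delayed value is always already known, so each step is an explicit lookup and no nonlinear solve is needed. The function solve_constant_delay and its arguments are hypothetical, and explicit Euler with a step size that divides the delay is assumed to keep the lookup exact.

```python
def solve_constant_delay(f, history, tau, t1, h):
    """Explicit Euler for y'(t) = f(t, y(t), y(t - tau)), with tau a multiple of h."""
    lag = round(tau / h)      # delay measured in whole steps
    n_steps = round(t1 / h)
    ys = [history(0.0)]
    for n in range(n_steps):
        t = n * h
        # Delayed value: stored solution if available, else the history function.
        y_delayed = ys[n - lag] if n >= lag else history(t - tau)
        ys.append(ys[n] + h * f(t, ys[n], y_delayed))
    return ys


# y'(t) = -y(t - 1), with history y(t) = 1 for t <= 0.
ys = solve_constant_delay(
    lambda t, y, y_del: -y_del, lambda t: 1.0, tau=1.0, t1=2.0, h=0.001
)
y_at_1, y_at_2 = ys[1000], ys[-1]
print(y_at_1, y_at_2)
```

For this test problem the exact solution is y(t) = 1 - t on [0, 1] and y(t) = 1 - t + (t - 1)^2 / 2 on [1, 2], so the printed values can be compared against y(1) = 0 and y(2) = -0.5.
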
    opened by patrick-kidger 24
  • Can't return solution of coupled differential equations

    I'm trying to solve a mid-sized system of coupled differential equations with diffrax. I'm using version 0.2.0. Here's a short snippet of dummy code that raises the issue I'm having:

    import jax.numpy as jnp
    from diffrax import diffeqsolve, ODETerm, Kvaerno3,PIDController
    
    def Results():
        def Y_prime(t, Y, args):
            dY = jnp.array([Y[6], (Y[5]-Y[6])**2,Y[0]+Y[7], (Y[1])**2, Y[2],Y[3], Y[4]**3, Y[5]**2])
            return dY
            
        t_init = 100
        t_fin = 1e5
    
        Yn_i = 1e-5
        Yp_i = 1e-6
        Yd_i = 1e-12
        Yt_i = 1e-12
        YHe3_i = 1e-12
        Ya_i = 1e-12
        YLi7_i = 1e-12
        YBe7_i = 1e-12
    
        Y0=jnp.array([[Yn_i], [Yp_i], [Yd_i], [Yt_i], [YHe3_i], [Ya_i], [YLi7_i], [YBe7_i]])
        term = ODETerm(Y_prime)
        solver = Kvaerno3()
        stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
        t_eval = jnp.logspace(jnp.log10(t_init),jnp.log10(t_fin),num=100)
        sol_at_MT = diffeqsolve(term, solver, t0=jnp.float64(t_init), t1=jnp.float64(t_fin), dt0=jnp.float64((t_eval[1]-t_eval[0])/10),y0=Y0,stepsize_controller=stepsize_controller,max_steps=None)
        Yn_MT_f, Yp_MT_f, Yd_MT_f, Yt_MT_f, YHe3_MT_f, Ya_MT_f, YLi7_MT_f, YBe7_MT_f = sol_at_MT.ys[-1][0][0],sol_at_MT.ys[-1][1][0],sol_at_MT.ys[-1][2][0],sol_at_MT.ys[-1][3][0],sol_at_MT.ys[-1][4][0],sol_at_MT.ys[-1][5][0],sol_at_MT.ys[-1][6][0],sol_at_MT.ys[-1][7][0]
    
        Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f = Yn_MT_f, Yp_MT_f, Yd_MT_f,Yt_MT_f,YHe3_MT_f,Ya_MT_f,YLi7_MT_f, YBe7_MT_f
        return jnp.array([Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f])
    Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f = Results()
    print(Yn_f)
    

    It seems diffrax successfully solves the differential equation, but struggles to return the output, i.e. it seems the code hangs when trying to assign values to the variable sol_at_MT. Tampering a bit with the diffrax source, it looks like there are two things going on.

    One is that, no matter what I try to return (even if I set all of the returns to None), if the lines right before the return in integrate.py

    branched_error_if(
        throw & jnp.invert(is_okay(result)),
        error_index,
        RESULTS.reverse_lookup,
    )
    

    aren't commented out, the code will freeze. I can include a print statement right after these lines (just before the return) that prints out successfully even when they're not commented, but I can't assign anything to sol_at_MT without the code hanging if these lines are left in.

    Then, if I comment that branched_error_if() call out, the code still hangs if I try to return ts, ys, stats or result from integrate.py. This doesn't seem to be an issue of time or memory; the code just freezes up and can't even be aborted from the command line whether I'm running locally or with extra resources on a cluster.

    question 
    opened by cgiovanetti 12
  • Handling discontinuities in time derivative?

    Hi, first of all, let me say that this looks like an amazing project. I am looking forward to playing around with this :).

    In a concrete problem I am dealing with, I have a forced system where the external force is piecewise constant. The external force changes at specific time points (t1, ..., tn), causing a discontinuity of the time derivative.
    I would like to use adaptive step-size solvers for increased accuracy, but naively applying adaptive step-size solvers will "waste" a lot of steps to find the point of change.

    Would including the change points in SaveAt avoid this problem? Or is there some other recommended way to handle this?
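
One general technique for this situation, sketched here in plain Python rather than with Diffrax, is to make the integrator land exactly on each change point, so that no step ever straddles a jump and the vector field is smooth within every segment. The sketch assumes a toy system y' = -y + u(t) and uses fixed-step RK4 per segment for brevity; an adaptive solver would be restarted at each change point in the same way. The names solve_piecewise and exact_piecewise are hypothetical. Whether and how Diffrax exposes this should be checked in its documentation.

```python
import math


def rk4_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def solve_piecewise(y0, t0, t1, change_points, forces, steps_per_segment=50):
    """Integrate y' = -y + u(t), where u is constant between change points."""
    edges = [t0, *change_points, t1]
    y = y0
    for a, b, u in zip(edges[:-1], edges[1:], forces):
        h = (b - a) / steps_per_segment
        for i in range(steps_per_segment):
            # u is fixed within the segment, so the vector field is smooth here.
            y = rk4_step(lambda t, y: -y + u, a + i * h, y, h)
    return y


def exact_piecewise(y0, t0, t1, change_points, forces):
    """Closed-form solution, applied segment by segment."""
    edges = [t0, *change_points, t1]
    y = y0
    for a, b, u in zip(edges[:-1], edges[1:], forces):
        y = u + (y - u) * math.exp(-(b - a))
    return y


problem = (1.0, 0.0, 4.0, [1.0, 2.5], [0.0, 2.0, -1.0])
numeric = solve_piecewise(*problem)
reference = exact_piecewise(*problem)
print(numeric, reference)
```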

    opened by jaschau 12
  • Slow `jit` compilation time compared to `jax.experimental.ode.odeint`

    hi @patrick-kidger, big fan of diffrax!

    I've been playing around with some of the functionality you have in this repository and comparing it with the ODE solver in jax. The one pain point I noticed is that there seems to be a relatively slow jit compilation time, particularly when I try to jit the grad of a simple loss function containing diffeqsolve. I was wondering if this is an error on my part (perhaps I botched the diffrax implementation) or if there is yet to be some optimization. The demonstration is below:

    from jax.config import config
    config.update("jax_enable_x64", True)
    config.update("jax_debug_nans", True) 
    config.parse_flags_with_absl()
    import jax
    import jax.numpy as jnp
    from jax import random
    import numpy as np
    from functools import partial
    import haiku as hk
    
    def exact_kinematic_aug_diff_f(t, y, args_tuple):
        """
        """
        _y, _, _ = y
        _params, _key, diff_f = args_tuple
        aug_diff_fn = lambda __y : diff_f(t, __y, (_params,))
        _f, s, t = aug_diff_fn(_y)
        r = jnp.sum(t)
        return _f, r, 0.
    
    def exact_kinematic_odeint_diff_f(y, t, params, canonical_diff_fn):
        run_y = y[0]
        _f, s, t = canonical_diff_fn(t, run_y, (params,))
        return _f, jnp.sum(s), 0.
    
    class TestMLP(hk.Module):
        def __init__(self, num_particles, name=None):
            super().__init__(name=name)
            self._mlp = hk.nets.MLP([8,8,8,8,num_particles*12])
            self._num_particles=num_particles
        def __call__(self, t, y):
            in_y = (y + t).flatten()
            outter = self._mlp(in_y).reshape((4, self._num_particles, 3))
            return outter[:2], outter[2], outter[3]
    
    def test(num_particles):
        import functools
        from jax.experimental.ode import odeint
        import diffrax
        
        #generate positions/velocities
        small_positions = jax.random.normal(jax.random.PRNGKey(261), shape=(num_particles,3))
        small_velocities = jax.random.normal(jax.random.PRNGKey(235), shape=(num_particles,3))
        small_positions_and_velocities = jnp.vstack([small_positions[jnp.newaxis, ...], small_velocities[jnp.newaxis, ...]])
        
        # make module kwargs
        VectorMLP_kwargs = {'num_particles': num_particles}
        
        # make module function
        def _diff_f_wrapper(t, y):
            diff_f = TestMLP(**VectorMLP_kwargs)
            return diff_f(t, y)
        
        diff_f_init, diff_f_apply = hk.without_apply_rng(hk.transform(_diff_f_wrapper))
        init_params = diff_f_init(jax.random.PRNGKey(36), 0., small_positions_and_velocities)
        canonicalized_diff_f_fn = lambda _t, _y, _args_tuple : diff_f_apply(_args_tuple[0], _t, _y)
        
        # make the augmented functions
        odeint_aug_diff_func = functools.partial(exact_kinematic_odeint_diff_f, canonical_diff_fn=canonicalized_diff_f_fn)
        diffeqsolve_aug_diff_func = exact_kinematic_aug_diff_f
        
        # odeint solver
        def odeint_solver(_parameters, _init_y, _key):
            aug_init_y = (_init_y, 0., 0.)
            outs = odeint(odeint_aug_diff_func, aug_init_y, jnp.array([0., 1.]), _parameters, rtol=1.4e-8, atol=1.4e-8)
            final_outs = (outs[0][-1], outs[1][-1], outs[2][-1])
            return final_outs
        
        def diffrax_ode_solver(_parameters, _init_y, _key):
            term=diffrax.ODETerm(diffeqsolve_aug_diff_func)
            stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
            solver = diffrax.Dopri5()
            aug_init_y = (_init_y, 0., 0.)
            sol = diffrax.diffeqsolve(term, 
                                      solver, 
                                      t0=0., 
                                      t1=1., 
                                      dt0=1e-1, 
                                      y0=aug_init_y, 
                                      stepsize_controller=stepsize_controller, 
                                      args=(_parameters, _key, canonicalized_diff_f_fn))
            return sol.ys[0][0], sol.ys[1][0], sol.ys[2][0]
        
        @jax.jit
        def odeint_loss_fn(_params, _init_y, _key):
            ode_solution = odeint_solver(_params, _init_y, _key)
            return jnp.sum(ode_solution[1]**2)
        
        @jax.jit
        def diffrax_loss_fn(_params, _init_y, _key):
            ode_solution = diffrax_ode_solver(_params, _init_y, _key)
            return jnp.sum((ode_solution[1])**2)
        
        # test
        import time
        
        # odeint compilation time
        start_time = time.time()
        _ = jax.grad(odeint_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
        end_time = time.time()
        print(f"odeint comp. time: {end_time - start_time}")
        
        # diffrax compilation time
        start_time = time.time()
        _ = jax.grad(diffrax_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
        end_time = time.time()
        print(f"diffrax comp. time: {end_time - start_time}")
    
    

    running test(8) gives me the following compilation time on CPU:

    odeint comp. time: 2.5580570697784424
    diffrax comp. time: 23.965799570083618
    

    I noticed that if I use diffrax.BacksolveAdjoint, compilation time goes down to ~8 seconds, but I'm keen to avoid that method based on your docs. Also, it looks like the compilation time in diffrax is heavily dependent on the number of hidden layers in TestMLP, perhaps suggesting a non-optimal compilation of for loops in diffrax? Thanks!

    refactor next 
    opened by dominicrufa 11
  • No GPU/TPU found, falling back to CPU

    Here's the full warning that I get (I do have a GPU):

    >>> import diffrax
    2022-03-24 16:30:19.350737: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:170] XLA service 0x55795c0d4670 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
    2022-03-24 16:30:19.350761: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:178]   StreamExecutor device (0): Interpreter, <undefined>
    2022-03-24 16:30:19.353414: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:169] TfrtCpuClient created.
    2022-03-24 16:30:19.353886: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    

    Edit: I installed diffrax from conda-forge.

    opened by ma-sadeghi 11
  • Logging metrics during an ODE solve

    Hello @patrick-kidger,

    thank you for open-sourcing this nice library! I was going to resume work on my own small ODE lib, but since this is much more elaborate than what I came up with so far, I am inclined to use this instead for a small project in the future.

    One question that came up to me when reading the source code: Is there currently a way to compute step-wise metrics during the solve? (Think logging step sizes, Jacobian eigenvalues, etc.)

    This would presumably happen in the integrate method. Could I e.g. use the solver_state pytree for this in, say, overridden solver classes? Thank you for your consideration.

    opened by nicholasjng 11
  • Brownian motion classes accept pytrees for shape and dtype arguments

    This changes the argument shape for classes VirtualBrownianTree and UnsafeBrownianPath, and adds an additional argument dtype as per the discussion in #180.

    • I decided upon shape: PyTree[Tuple[int, ...]] instead of shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]]. It's unclear what to do with named_shape in jax.ShapeDtypeStruct -- I don't know if there is a way to sample Brownian motion via named shapes. But if you feel strongly about this and give me some pointers, I can reimplement.
    • To allow specifying dtypes, dtype argument specifies them as a pytree and has to be a prefix tree of shape.
    • I added __init__ methods to both classes since I was not sure how to have dtype=None without it.
    • Added some helper functions that I use in misc.py, hope that's the right location to place them.
    • Used jtu.tree_map instead of jax.vmap -- was not sure how to supply is_leaf to jax.vmap. Happy to change this as well, with some pointers.
    • Changed the test_brownian.py:test_shape to test pytree shapes and dtypes. Just noticed that formatting made it look pretty bad, not sure if that's a big deal.
    • Tests pass locally.

    Let me know what you think. Thanks!

    opened by ciupakabra 9
  • added new kalman-filter example

    I wrote a little additional example that showcases diffrax in a maybe not so obvious way. It also showcases equinox and the ability to freeze parameters. Let me know what you think (and what needs to be changed). Greetings

    opened by SimiPixel 8
  • Performance against `jax.experimental.ode.odeint`

    Hi @patrick-kidger, I was excited to test out Diffrax in our code. However, we found it did not perform as well as expected. This is likely due to nuances on our end, but because of https://github.com/google/jax/issues/9654 I thought I would post a MWE.

    import diffrax
    import jax
    import ticktack
    
    PARAMS = (774.86, 0.25, 0.8, 6.44)
    
    STEADY_PROD = 1.8803862513018528
    
    STEADY_STATE = jax.numpy.array(
        [1.34432991e+02, 7.07000000e+02, 1.18701144e+03,
        3.95666872e+00, 4.49574232e+04, 1.55056740e+02,
        6.32017337e+02, 4.22182768e+02, 1.80125397e+03,
        6.63307283e+02, 7.28080320e+03], 
        dtype=jax.numpy.float64)
    
    PROD_COEFFS = jax.numpy.array(
        [0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
        dtype=jax.numpy.float64)
    
    MATRIX = jax.numpy.array([
        [-0.509, 0.009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.508, -0.44, 0.068, 0.0, 0.0, 0.545, 0.0, 0.167, 0.002, 0.002, 0.0],
        [0.0, 0.121, -0.155, 12.0, 0.001, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0],
        [0.0, 0.0, 4.4000e-02, -1.3333e+01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.042, 1.333, -0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.229, 0.0, 0.0, 0.0, -1.046, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.136, -0.033, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.364, 0.033, -0.183, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, -0.002, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, -0.002, 0.0],
        [0.0, 0.0, 3.333e-04, 0.0, 5.291e-06, 0.0, 0.0, 0.0, 0.0, 4.0e-04, -1.2340e-04]], 
        dtype=jax.numpy.float64)
    
    @jax.jit 
    def driving_term(t, args):
        start_time, duration, phase, area = jax.numpy.array(args)
        middle = start_time + duration / 2.
        height = area / duration
    
        gauss = height * \
            jax.numpy.exp(- ((t - middle) / (0.5 * duration)) ** 16.)
        sine = STEADY_PROD + 0.18 * STEADY_PROD *\
            jax.numpy.sin(2 * jax.numpy.pi / 11 * t + phase * 2 * jax.numpy.pi / 11)
    
        return (sine + gauss) * 3.747
    
    @jax.jit
    def jax_dydt(y, t, args, /, matrix=MATRIX, production=driving_term, 
                       prod_coeffs=PROD_COEFFS):
        ans = jax.numpy.matmul(matrix, y)
        production_rate_constant = production(t, args)
        production_term = prod_coeffs * production_rate_constant
        return ans + production_term
    
    @jax.jit
    def diffrax_dydt(t, y, args, /, matrix=MATRIX, production=driving_term, 
                     prod_coeffs=PROD_COEFFS):
        ans = jax.numpy.matmul(matrix, y)
        production_rate_constant = production(t, args)
        production_term = prod_coeffs * production_rate_constant
        return ans + production_term
    
    time_out = jax.numpy.linspace(750, 800, 1000)
    
    %%timeit
    jax.experimental.ode.odeint(jax_dydt, STEADY_STATE, time_out, PARAMS)
    
    term = diffrax.ODETerm(diffrax_dydt)
    solver = diffrax.Bosh3()
    step_size = diffrax.PIDController(rtol=1e-10, atol=1e-10)
    save_time = diffrax.SaveAt(ts=time_out)
    
    %%timeit
    diffrax.diffeqsolve(args=PARAMS, terms=term, solver=solver, y0=STEADY_STATE,
                        t0=time_out.min(), t1=time_out.max(), dt0=0.01,
                        saveat=save_time, stepsize_controller=step_size, 
                        max_steps=10000)
    

    Sorry that the example is so voluminous, but I wanted to keep it very similar to our code.

    Thanks in advance.

    Jordan

    opened by Jordan-Dennis 8
  • Weird behaviour due to defaults when using Implicit-Euler

    When using dfx.ImplicitEuler() with everything set to default an error is raised

    missing rtol and atol of NewtonNonlinearSolver

    You are then prompted to set these values in the stepsize-controller, because by default the solver is supposed to fall back to the values provided in PIDController. But dfx.ImplicitEuler() does not support adaptive step-sizing using a PIDController.

    The solution is to use

    solver=dfx.ImplicitEuler(nonlinear_solver=dfx.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    

    Just something that feels a bit odd.
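
To show why an implicit solver needs its own tolerances even with a fixed step size, here is a plain-Python sketch of a single implicit Euler step for a scalar ODE. It is not Diffrax's NewtonNonlinearSolver; the function implicit_euler_step, its dfdy argument, and the convergence test are assumptions made for illustration. The step has to solve z = y + h * f(t + h, z) iteratively, and rtol and atol decide when that iteration stops.

```python
def implicit_euler_step(f, dfdy, t, y, h, rtol=1e-3, atol=1e-6, max_iters=20):
    """Solve z = y + h * f(t + h, z) for z by Newton's method (scalar case)."""
    z = y  # initial guess: the previous value
    for _ in range(max_iters):
        residual = z - y - h * f(t + h, z)
        dz = -residual / (1.0 - h * dfdy(t + h, z))
        z += dz
        # Converged once the Newton update is small relative to atol + rtol * |z|.
        if abs(dz) <= atol + rtol * abs(z):
            return z
    raise RuntimeError("Newton iteration did not converge")


# Nonlinear test problem: y' = -50 * y**3, one step of size 0.1 from y = 1.
z = implicit_euler_step(
    f=lambda t, y: -50.0 * y**3,
    dfdy=lambda t, y: -150.0 * y**2,
    t=0.0,
    y=1.0,
    h=0.1,
)
print(z, z + 5.0 * z**3 - 1.0)  # solution and remaining residual
```

The tolerances here belong to the nonlinear solve and are independent of any step size control, which is the distinction the issue points at.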

    refactor 
    opened by SimiPixel 6
  • Transform Feedforward-Network + solver into a Recurrent-Network

    Hello Patrick,

    let me first quickly motivate my feature request. As a side project I am currently working on model-based optimal control. Stateful agents are useful for, e.g., an only partially observable environment. So, suppose the action selection of an agent is given by the following method

    def select_action(params, state, observation, time):
        apply = neural_network.apply
        state, action = apply(params, state, observation, time)
        return state, action
    
    while True:
        action = select_action(..., observation, env.time)
        observation = env.step(action)
    

    Typically, the apply-function is some recurrent neural network. Suppose the environment env is differentiable, because it is just some model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent.

    I was wondering if constructing the following is possible and sensible? I.e. I would like to transform a choice of Feedforward-Network + Solver into a Recurrent-Network.

    def select_action(params, ode_state, observation, time):
        rhs = lambda x,u: neural_network.apply(params, x, u)
        solution, ode_state = odeint(ode_state, rhs, t1=time, u=(observation, time))
        return ode_state, solution.x(time)
    

    I would like to emphasize that this select_action must remain differentiable: the x-output w.r.t. the network parameters.

    I would love to hear your input :) Anyways thank you in advance.

    opened by SimiPixel 5
  • ODE solver fails with 'The maximum number of solver steps was reached. Try increasing `max_steps`'

    Hi,

    I was playing with this cool package on a chemical reaction ODE problem. This problem solves the time evolution of seven chemical concentrations, which is a stiff problem but can be solved using a Fortran-based solver. However, the diffrax version fails, with an XlaRuntimeError complaining 'The maximum number of solver steps was reached. Try increasing max_steps'. Unfortunately, the error persists no matter how large max_steps is and which solver is used (e.g., ImplicitEuler or Kvaerno5). Note that when commenting out the error message in the diffeqsolve function, I find that the code can solve about the first 100s and outputs inf (from solution.ys) at later times.

    Any suggestion would be appreciated!

    Below is the code snippet --

    from diffrax import diffeqsolve, ODETerm, SaveAt
    from diffrax import NewtonNonlinearSolver, Dopri5, Kvaerno3, ImplicitEuler, Euler, Kvaerno5
    from diffrax import PIDController
    
    import jax
    import jax.numpy as jnp
    import jax.random as jrandom
    
    from jax.config import config
    config.update("jax_enable_x64", True)
    
    def funclog2(t, logy, args):
        k1, k2, k3 = args[0], args[1], args[2]
        kd1, kd2, kd3 = args[3], args[4], args[5]
        ka1, ka2, ka3 = args[6], args[7], args[8]
        r4 = args[9]
        
        y = jnp.power(10, logy)
        doc, o2, no3, no2, n2, co2, bm = y
        
        # log transform scale
        scale = 1 / jnp.log(10)
        scale = scale / y
        
        # The stoichiometry matrix
        stoich = jnp.array([
            [-1, -1, -1, 5],
            [0, 0, -1, 0],
            [-2, 0, 0, 0],
            [1, -1, 0, 0],
            [0, 1, 0, 0],
            [1, 1, 1, 0],
            [0, 0, 0, -1]
        ])
        
        # Scale stoich
        stoich = jax.vmap(lambda a, b: a*b, in_axes=0)(scale, stoich)
        
        # Reaction rate
        r1 = k1 * bm * doc/(doc+kd1) * no3/(no3+ka1)
        r2 = k2 * bm * doc/(doc+kd2) * no2/(no2+ka2)
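        # Note: the denominator uses no2, whereas r1 and r2 pair each substrate with itself; o2/(o2+ka3) may have been intended.
        r3 = k3 * bm * doc/(doc+kd3) * o2/(no2+ka3)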
        
        r = jnp.array([r1, r2, r3, r4]).T
        
        return stoich @ r
    
    # Static parameters
    k1, k2, k3 = 3.24e-4, 2.69e-4, 9e-4 # [mol/L/sec/mass [BM]]
    kd1, kd2, kd3 = 2.5e-4, 2.5e-4, 2.5e-4 # [mol/L]
    ka1, ka2, ka3 = 1e-6, 4e-6, 1e-6  # [mol/L]
    r4 = 2.8e-6 # [mol/L/sec]
    args = jnp.array([k1, k2, k3, kd1, kd2, kd3, ka1, ka2, ka3, r4])
    
    # The initial concentrations with the following order [mol/L]:
    # doc, o2, no3, no2, n2, co2, bm
    # y0 = jnp.array([4.16e-05, 0.000266, 0.000396, 1e-10, 1e-10, 0.00248, 0.0003])
    y0 = jnp.array([4.16e-05, 0.000266, 0.000396, 1e-3, 1e-3, 0.00248, 0.0003])
    logy0 = jnp.log10(y0)
    
    term = ODETerm(funclog2)
    # solver = Dopri5()
    # solver = Euler()
    solver = Kvaerno5(NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    # solver = ImplicitEuler(NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    # t0, t1, dt0 = 0, 3600*24*30, 1
    t0, t1, dt0 = 0, 200, 0.01
    # t0, t1, dt0 = 0, 3600*24, 3600
    solution = diffeqsolve(term, solver, t0=t0, t1=t1, dt0=dt0, max_steps=400000,
                           stepsize_controller=PIDController(rtol=1e-3, atol=1e-6),
                           saveat = SaveAt(t0=True, ts=jnp.linspace(t0,t1)), 
                           y0=logy0, args=args)
    solution.stats
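
The "log transform scale" step in the snippet above is the chain rule for u = log10(y): du/dt = (dy/dt) / (y * ln 10). The following plain-Python sketch isolates that transform on a toy problem; the wrapper name to_log10_field is hypothetical and the scalar decay equation stands in for the reaction system.

```python
import math


def to_log10_field(f):
    """Wrap dy/dt = f(t, y) into the vector field for u = log10(y)."""
    def log_field(t, u):
        y = 10.0 ** u
        # Chain rule: du/dt = (dy/dt) / (y * ln 10).
        return f(t, y) / (y * math.log(10.0))
    return log_field


k = 3.0
log_field = to_log10_field(lambda t, y: -k * y)

u = math.log10(2.0)  # start from y(0) = 2
h = 0.01
for n in range(100):
    u += h * log_field(n * h, u)  # explicit Euler in log space
y_final = 10.0 ** u

print(y_final, 2.0 * math.exp(-k))
```

For pure exponential decay the transformed vector field is the constant -k / ln 10, so even explicit Euler is exact in log space up to rounding. The reaction system above will not be this benign, and the transform also divides by y, so very small concentrations produce very large derivatives.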
    
    question 
    opened by PeishiJiang 5
  • Truncated Back Propagation through time

    Hi, I was wondering if it is possible to integrate truncated back propagation through time (TBPTT) into Diffrax. I couldn't find any options for this in Diffrax or Equinox, nor could I find any implementation of TBPTT in the source code in integrate.py, but maybe I missed it. My best guess would be to write a custom adjoint class that would implement TBPTT, but I am not sure how to do this. My question is: would it be possible to (easily) implement TBPTT to train my NDEs and how should I approach this?

    feature 
    opened by sdevries0 1
  • Fastest way to evaluate a solution

    Hi, suppose I have a simple ODE that I solve with diffrax. What would be the fastest way to use the solution in another piece of code? I need to evaluate the solution at some points not known in advance, and I thought of generating a dense solution sol and then using its evaluate method on the points of interest, i.e. every time I need it, call sol.evaluate() on my points of interest (using vmap when needed). Is this the most efficient way, or should I interpolate a fixed-grid solution myself and create a jitted function that evaluates it on my points of interest?
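
For reference, the second option in the question (interpolating a fixed-grid solution by hand) could look like the plain-Python sketch below. It builds a cubic Hermite interpolant from saved values and slopes, where the slopes come from evaluating the vector field at the saved points. The name hermite_interpolant is hypothetical, and the grid values are a stand-in for real solver output. This does not establish which approach is faster; it only shows what the manual route involves.

```python
import bisect
import math


def hermite_interpolant(ts, ys, dys):
    """Return a function evaluating a cubic Hermite interpolant through (ts, ys) with slopes dys."""
    def evaluate(t):
        # Locate the grid interval containing t, clamped to the valid range.
        i = min(max(bisect.bisect_right(ts, t) - 1, 0), len(ts) - 2)
        h = ts[i + 1] - ts[i]
        s = (t - ts[i]) / h
        h00 = 2 * s**3 - 3 * s**2 + 1
        h10 = s**3 - 2 * s**2 + s
        h01 = -2 * s**3 + 3 * s**2
        h11 = s**3 - s**2
        return h00 * ys[i] + h10 * h * dys[i] + h01 * ys[i + 1] + h11 * h * dys[i + 1]
    return evaluate


ts = [0.1 * n for n in range(11)]
ys = [2.0 * math.exp(-t) for t in ts]  # stand-in for saved solver output
dys = [-y for y in ys]                 # slopes from the vector field f(t, y) = -y
evaluate = hermite_interpolant(ts, ys, dys)

print(evaluate(0.37), 2.0 * math.exp(-0.37))
```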

    question 
    opened by marcofrancis 1
  • Make diffeqsolve convertable to TensorFlow

    Based on a talk on NODEs on YouTube I came across this package, and it looks perfect for a project we are planning (thanks for the great talk!). Now, one of the platforms where we want to run our code does not support JAX/XLA/TensorFlow, just ONNX. I tried converting a simulation function to TensorFlow for later conversion to ONNX, but this fails because the unsupported unvmap_any is used (at compile time!) to deduce the number of iterations needed.

    Minimal example:

    import tensorflow as tf
    import jax.numpy as jnp
    import tf2onnx
    
    from diffrax import diffeqsolve, ODETerm, Euler
    from jax.experimental import jax2tf
    
    def simulate(y0):
        solution = diffeqsolve(
                terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
                t0=0, t1=1, dt0=0.1, y0=y0)
        return solution.ys[0]
    
    # This works
    x = simulate(100)
    assert jnp.isclose(x, jnp.exp(-1)*100, atol=.1, rtol=.1)
    
    simulate_tf = tf.function(jax2tf.convert(simulate, enable_xla=False))
    
    # Does not work:
    # simulate_tf(100)
    # => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
    
    # Also does not work:
    tf2onnx.convert.from_function(
            simulate_tf, input_signature=[tf.TensorSpec((), tf.float32)])
    # simulate_tf(100)
    # => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
    

    For us, it would be really nice to use a GPU/TSP during training with jax, then transfer to this specific piece of hardware with just ONNX support for inference (at this point I don't need gradient calculation anymore). Of course, solving this might be completely outside the scope of the project and there are other solutions like writing the solvers from scratch or using existing solvers in TF/PyTorch.

    Currently my knowledge of JAX is limited (hopefully this will soon improve!). If this is the only function stopping Diffrax from being TensorFlow-convertible, maybe a small workaround could be possible. I'm also happy with an answer like 'no we don't do this' or 'send us a PR if you want to have this fixed'.

    feature 
    opened by llandsmeer 6
  • Question about BacksolveAdjoint through SemiImplicitEuler solver

    I am testing the adjoint method to calculate the gradients from a SemiImplicitEuler solver. I get errors when calculating the gradients using the BacksolveAdjoint method. Here is a working example. It would be great to have some suggestions.

    Thank you in advance!

    from diffrax import diffeqsolve, ODETerm, SemiImplicitEuler, SaveAt, BacksolveAdjoint
    import jax.numpy as jnp
    from jax import grad
    from matplotlib import pyplot as plt

    def drdt(t, v, args):
        return v

    def dvdt(t, r, args):
        return -args[0]*(r-args[1])

    terms = (ODETerm(drdt), ODETerm(dvdt))
    solver = SemiImplicitEuler()
    y0 = (jnp.array([1.0]), jnp.array([0.0]))
    saveat = SaveAt(ts=jnp.arange(0,30,0.1))

    def loss(y0):
        solution = diffeqsolve(terms, solver, t0=0, t1=30, dt0=0.0001, y0=y0, args=[1.0,0.0], saveat=saveat, max_steps=10000000, adjoint=BacksolveAdjoint())
        return jnp.sum(solution.ys[0])

    grads = grad(loss)(y0)
    print(grads)

    here is the error message:

    Traceback (most recent call last):
      File "test_harmonic.py", line 23, in <module>
        grads = grad(loss)(y0)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 482, in fn_bwd_wrapped
        out = fn_bwd(residuals, grad_diff_array_out, vjp_arg, *args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 394, in _loop_backsolve_bwd
        state, _ = _scan_fun(state, val0, first=True)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 332, in _scan_fun
        _sol = diffeqsolve(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 82, in __call__
        return __self._fun_wrapper(False, args, kwargs)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 78, in _fun_wrapper
        dynamic_out, static_out = self._cached(dynamic, static)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 30, in fun_wrapped
        out = fun(*args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 858, in diffeqsolve
        final_state, aux_stats = adjoint.loop(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 499, in loop
        final_state, aux_stats = _loop_backsolve(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 509, in __call__
        out = self.fn_wrapped(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 443, in fn_wrapped
        out = self.fn(vjp_arg, *args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 250, in _loop_backsolve
        return self._loop_fn(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 497, in loop
        final_state = bounded_while_loop(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 125, in bounded_while_loop
        return lax.while_loop(cond_fun, _body_fun, init_val)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 118, in _body_fun
        _new_val = body_fun(_val, inplace)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 137, in body_fun
        (y, y_error, dense_info, solver_state, solver_result) = solver.step(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/solver/semi_implicit_euler.py", line 42, in step
        y0_1, y0_2 = y0
    ValueError: too many values to unpack (expected 2)
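The final frame of the traceback is the line `y0_1, y0_2 = y0` in `semi_implicit_euler.py`, reached from `_loop_backsolve_bwd`. That suggests the state arriving at the solver during the backward pass did not have exactly two components. The following is a minimal pure-Python sketch of that failure mode. The `step` function here is a hypothetical stand-in written for illustration, not Diffrax's actual implementation.

```python
def step(y0, dt):
    # Hypothetical stand-in for a semi-implicit Euler step: the state is
    # unpacked into exactly two parts, as in the traceback's final frame.
    y0_1, y0_2 = y0
    y1_1 = y0_1 + dt * y0_2   # update the first part using the old second part
    y1_2 = y0_2 - dt * y1_1   # update the second part using the new first part
    return (y1_1, y1_2)


# A two-component state unpacks cleanly.
pair_state = (1.0, 0.0)
new_state = step(pair_state, 0.1)

# A state with more than two components fails at the unpacking line.
larger_state = (1.0, 0.0, 0.5)
try:
    step(larger_state, 0.1)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(new_state)
print(error_message)
```

The exact wording of the `ValueError` can differ slightly between Python versions, but it begins with "too many values to unpack".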

    bug feature 
    opened by Chenghao-Wu 1
Releases(v0.2.2)
  • v0.2.2(Nov 15, 2022)

    Performance improvements

    • Now makes fewer vector field traces in several cases. (#172, #174)

    Fixes

    • Many documentation improvements.
    • Fixed several warnings about jax.{tree_map,tree_leaves,...} being moved to jax.tree_util.{tree_map,tree_leaves,...}. (Thanks @jacobusmmsmit!)
    • Fixed the step size controller choking if the error is ever NaN. (#143, #152)
    • Fixed some crashes due to JAX-internal changes. (If you've ever seen it throw an error about not knowing how to rewrite closed_call_p, it's this one.)
    • Fixed an obscure edge-case NaN on the backward pass, if you were using an implicit solver with an adaptive step size controller, got a rejected step due to the implicit solve failing to converge, and happened to also be backpropagating wrt the controller_state.
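The `jax.tree_util` fix above concerns pytree helpers that are called through the `jax.tree_util` namespace rather than as top-level `jax.*` functions. A minimal sketch of the namespaced calls follows. The `state` pytree and its key names are made up for illustration and are not taken from Diffrax.

```python
import jax.tree_util as jtu

# A small pytree: a dict holding a scalar and a tuple of scalars.
# (Hypothetical example data.)
state = {"position": 1.0, "velocity": (2.0, 3.0)}

# Apply a function to every leaf, preserving the container structure.
doubled = jtu.tree_map(lambda x: 2 * x, state)

# Flatten the pytree into a list of leaves (dict keys are visited in sorted order).
leaves = jtu.tree_leaves(state)

print(doubled)
print(leaves)
```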

    Other

    • Added a new Kalman filter example (#159) (Thanks @SimiPixel!)
    • Brownian motion classes accept pytrees for shape and dtype arguments (#183) (Thanks @ciupakabra!)
    • The main change is an internal refactor: a lot of functionality has moved from diffrax.misc to equinox.internal.

    New Contributors

    • @jacobusmmsmit made their first contribution in https://github.com/patrick-kidger/diffrax/pull/149
    • @SimiPixel made their first contribution in https://github.com/patrick-kidger/diffrax/pull/159
    • @ciupakabra made their first contribution in https://github.com/patrick-kidger/diffrax/pull/183

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.2.1...v0.2.2

  • v0.2.1(Aug 3, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Made is_okay,is_successful,is_event public by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/134
    • Fix implicit adjoints assuming array-valued state by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/136
    • Replace jax tree manipulation method that are being deprecated with jax.tree_util equivalents by @mahdi-shafiei in https://github.com/patrick-kidger/diffrax/pull/138
    • bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/141

    New Contributors

    • @mahdi-shafiei made their first contribution in https://github.com/patrick-kidger/diffrax/pull/138

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.2.0...v0.2.1

  • v0.2.0(Jul 20, 2022)

    • Feature: event handling. In particular it is now possible to interrupt a diffeqsolve early. See the events page in the docs and the new steady state example.
    • Compilation time improvements:
      • The compilation speed of NewtonNonlinearSolver (and thus in practice also all implicit solvers like Kvaerno3 etc.) has been improved by a factor of roughly 1.5.
      • The compilation speed of all Runge--Kutta solvers can be dramatically reduced (~factor 3) by passing e.g. Dopri5(scan_stages=True). This may increase runtime slightly. At the moment the default is scan_stages=False for all solvers, but this default might change in the future.
    • Various documentation improvements.

    New Contributors

    • @jatentaki made their first contribution in https://github.com/patrick-kidger/diffrax/pull/121

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.2...v0.2.0

  • v0.1.2(May 18, 2022)

    The main change here is a minor technical one: Diffrax will no longer initialise the JAX backend as a side effect of being imported.


    Autogenerated release notes as follows:

    What's Changed

    • Removed explicit jaxlib dependency by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/93
    • switch error_if to python if (regarding google/jax/issues/10047) by @amir-saadat in https://github.com/patrick-kidger/diffrax/pull/99
    • Doc fixes by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/100
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/107

    New Contributors

    • @amir-saadat made their first contribution in https://github.com/patrick-kidger/diffrax/pull/99

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.1...v0.1.2

  • v0.1.1(Apr 7, 2022)

    Diffrax uses some JAX-internal functionality that will shortly be deprecated in JAX. This release adds the appropriate support for both older and newer versions of JAX.


    Autogenerated release notes as follows:

    What's Changed

    • [JAX] Add MHLO lowerings in preparation for xla.lower_fun() removal by @hawkinsp in https://github.com/patrick-kidger/diffrax/pull/91
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/92

    New Contributors

    • @hawkinsp made their first contribution in https://github.com/patrick-kidger/diffrax/pull/91

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.0...v0.1.1

  • v0.1.0(Mar 30, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Adjusted PIDController by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/89

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.6...v0.1.0

  • v0.0.6(Mar 29, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Symbolic regression text by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/79
    • Fixed edge case infinite loop on stiff-ish problems (+very bad luck) by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/86

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.5...v0.0.6

  • v0.0.5(Mar 21, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Doc tweaks by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/72
    • Added JIT wrapper to stiff ODE example by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/75
    • Added autoreleases by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/78
    • Removed overheads from runtime checking when they can be compiled away. by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/77

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.4...v0.0.5

  • v0.0.4(Mar 6, 2022)

    First release using GitHub releases! We'll be using this to serve as a changelog.

    As for what has changed since the v0.0.3 release, we'll let the autogenerated release notes do the talking:

    What's Changed

    • Rewrote RK implementation quite substantially to allow FSAL RK SDE integrators. by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/70

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.3...v0.0.4

Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)