Overview

JAXopt

Installation | Examples | References

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Installation

JAXopt can be installed with pip directly from GitHub with the following command:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source with the following command:

$ python setup.py install

References

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Comments
  • Levenberg-Marquardt running exceptionally slowly unless verbose is enabled

    The Levenberg-Marquardt optimizer runs exceptionally slowly (~30 seconds for 30 iterations) until I turn on verbose=True (~1 second for 30 iterations). Any idea what may be going on? Enabling JIT seems to have no impact. I was hoping to use this for a real-time system, but even at 1 second things are way too slow.
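
    For reference, a minimal timing sketch (not from the issue) of jaxopt.LevenbergMarquardt on a toy least-squares problem; the residual function is purely illustrative, and the first call includes JIT compilation time:

    # Hypothetical micro-benchmark: time LevenbergMarquardt on a tiny problem.
    import time
    import jax.numpy as jnp
    from jaxopt import LevenbergMarquardt
    
    def residual_fun(params):
        # simple residuals with a known minimum at (1.0, 2.0)
        return jnp.array([params[0] - 1.0, 10.0 * (params[1] - 2.0)])
    
    solver = LevenbergMarquardt(residual_fun=residual_fun, maxiter=30)
    
    start = time.perf_counter()
    sol = solver.run(jnp.zeros(2))
    sol.params.block_until_ready()  # wait for async dispatch before reading the clock
    print("elapsed:", time.perf_counter() - start, "params:", sol.params)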

    opened by pablovela5620 23
  • implementation of Fletcher-Reeves Algorithm

    • Polak-Ribière method: to my knowledge, conjugate gradient variants have been quite successful on general unconstrained optimization problems (see the sketch below).

    This PR depends on the line search from PR #128.

    • Beta division is required to guarantee the strong Wolfe condition, but (I don't know why) it raises an error.
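
    As context, a hedged sketch of selecting a conjugate-gradient update rule through jaxopt.NonlinearCG; the exact set of accepted method names is an assumption on my part:

    import jax.numpy as jnp
    from jaxopt import NonlinearCG
    
    def rosenbrock(x):
        return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)
    
    # method="fletcher-reeves" / "polak-ribiere" / "hestenes-stiefel" (assumed names)
    solver = NonlinearCG(fun=rosenbrock, method="polak-ribiere", maxiter=200)
    params, state = solver.run(jnp.zeros(3))
    print(params)  # should approach [1., 1., 1.]
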
    pull ready 
    opened by ita9naiwa 17
  • vmap support in QPs

    Hi, I'm experiencing a problem with projection_polyhedron:

    import numpy as np
    import matplotlib.pyplot as plt
    
    import jax
    import jax.numpy as jnp
    
    import jaxopt
    from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron
    
    def myproj3(x):
        A = jnp.array([[1.0, 1.0]])
        b = jnp.array([1.0])
        G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
        h = jnp.array([0.0, 0.0])    
        x = projection_polyhedron(x,hyperparams = (A, b, G, h))
        return x
    
    rng_key = jax.random.PRNGKey(42)
    x = jax.random.uniform(rng_key, (5000,2), minval=-3,maxval=3)
    p1_x = jax.vmap(myproj3)(x)  # myproj3 takes a single argument
    fig, ax = plt.subplots(figsize=(5,5))
    ax.scatter(x[:,0],x[:,1],s=0.5)
    ax.scatter(p1_x[:,0],p1_x[:,1],s=0.5,c='g')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.show()
    

    First, I had to install cvxpy (pip install cvxpy). Then I got this error:

    TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
      with val = DeviceArray([[-2.37103211,  2.33759997],
                              [ 2.76953806, -2.37750394],
                              [-0.87246632,  0.73224625],
                              ...,
                              [ 2.29799773,  2.81894884],
                              [ 2.4022714 ,  0.80693103],
                              [-0.41563116,  2.83898531]], dtype=float64)
           batch_dim = 0
    

    Does anyone have a hint? Thanks!
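
    For contrast, a hedged sketch showing that the pure-JAX projections in the same module vmap without trouble (assuming projection_box takes a (lower, upper) hyperparameters tuple); the cvxpy-backed projection_polyhedron cannot be traced this way:

    import jax
    import jax.numpy as jnp
    from jaxopt.projection import projection_box
    
    x = jax.random.uniform(jax.random.PRNGKey(0), (1000, 2), minval=-3.0, maxval=3.0)
    # vmap the projection over the batch dimension; each row is clipped to [-1, 1]^2
    proj = jax.vmap(lambda row: projection_box(row, hyperparams=(-1.0, 1.0)))(x)
    print(proj.min(), proj.max())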

    enhancement 
    opened by jecampagne 12
  • KKT conditions when the primal solution is a pytree

    Hi, congrats on the great tool! Inspired by the QuadraticProgramming example, I built code that differentiates through KKT conditions. My code works whenever the primal solution variable is a jnp array, but not when it's a generic pytree, giving me the following issue:

    TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

    where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

    I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).
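
    A minimal sketch of the flatten/unflatten workaround described above, using jax.flatten_util.ravel_pytree; the pytree and objective here are toy placeholders:

    import jax.numpy as jnp
    from jax.flatten_util import ravel_pytree
    
    # A toy pytree-valued primal and a toy objective over it.
    primal_pytree = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    
    def objective(params):
        return jnp.sum(params["w"] ** 2) + jnp.sum(params["b"] ** 2)
    
    flat_primal, unravel = ravel_pytree(primal_pytree)
    
    def objective_flat(flat_x):
        return objective(unravel(flat_x))  # rebuild the pytree inside the objective
    
    # A solver can now operate on the flat array; unravel() recovers the pytree.
    print(objective_flat(flat_primal))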

    To make it easier to reproduce I modified the quadratic_prog.py file by making the model return a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified the obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't, this test line raises an assert for an array that should be all zeros and instead is: ([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

    Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.

    Notice also that the test on the line just above (checking that the primal solution is correct) still holds, provided we check sol[0][0] instead of sol[0] (since sol[0] is now a one-element list).

    Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?

    Thanks!

    opened by FerranAlet 11
  • Hot fix: corrected condition in lbfgs

    The feature I had introduced in https://github.com/google/jaxopt/pull/323 was failing when the run function was jitted, and was a no-op when it was not, for the following reason:

     ~True == -2  # this is True
    

    Therefore when jitted it was complaining about different types in a condition function, and when not jitted it was equivalent to always being False.

    EDIT

    Actually I am still running into an error when jitted, so will continue to investigate.

    The gist of the error is "Abstract tracer value encountered where concrete value is expected": basically, (not self.stop_if_linesearch_fails | ~state.failed_linesearch) is not allowed because one operand is a Python bool and the other is an abstract value.
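
    A quick illustration of the pitfall, with jnp.logical_not as the boolean-preserving alternative:

    import jax.numpy as jnp
    
    print(~True)                  # -2: bitwise NOT of int(True), not a boolean
    print(~True == -2)            # True
    print(jnp.logical_not(True))  # False; stays boolean and works on traced values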

    pull ready 
    opened by zaccharieramzi 7
  • Issue with gradients wrt optimality fn parameters through root finding vjp

    First of all, thanks a lot for this library! Really useful tools! I'm interested in getting at least 2nd order gradients through root finding, and I'm finding an odd behavior that I wanted to report.

    Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:

    # Imports for the snippet; F is the optimality function whose root is sought.
    import jax
    from jaxopt import Bisection
    
    def inv_f(x, aux):
      bisec = Bisection(optimality_fun=F, lower=0.0, upper=1.,
                        check_bracket=False, unroll=True)
      return bisec.run(aux=aux).params
    
    # Here I extract the value part of the vjp, but the grad part also gives wrong results
    test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0] 
    
    jax.grad(test_fn)(1.) # Returns 0 instead of the expected gradients
    

    Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (but not as obvious).

    I made a small demo notebook that reproduces this issue here.

    As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.

    Reading the source code of jaxopt, it is not immediately obvious to me why this doesn't work... Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).

    bug 
    opened by EiffL 7
  • misc improvements to robust training example

    main changes:

    • Fixes #134 by normalizing in-place.
    • Plot convergence curves for both clean and adversarial accuracy.
    • Replace the fast-sign-gradient method with the much more powerful PGD method (a sketch of PGD follows this list).
    • Be able to select different datasets.
    • Homogenize the API with respect to the other examples. For example, this now uses the same load_dataset, CNN, loss_fun and accuracy as flax_image_classif.py. Most of the command-line flags have also been homogenized.
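
    A hedged sketch of a projected-gradient-descent (PGD) attack in JAX; loss_fun, params, x and y are generic placeholders rather than names from the example:

    import jax
    import jax.numpy as jnp
    
    def pgd_attack(loss_fun, params, x, y, epsilon=8 / 255, alpha=2 / 255, num_steps=10):
        """Finds an adversarial perturbation of x inside an L-inf ball of radius epsilon."""
        x_adv = x
        for _ in range(num_steps):
            grad_x = jax.grad(lambda xi: loss_fun(params, xi, y))(x_adv)
            x_adv = x_adv + alpha * jnp.sign(grad_x)           # ascent step on the loss
            x_adv = jnp.clip(x_adv, x - epsilon, x + epsilon)  # project back into the ball
            x_adv = jnp.clip(x_adv, 0.0, 1.0)                  # keep a valid pixel range
        return x_adv
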
    pull ready 
    opened by fabianp 7
  • Bisection hanging

    I am trying to use jaxopt.Bisection to replace scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.

    The basic structure includes two functions that are both jitted (so I assume it should compile fine):

    @jit
    def f1(parameters):
        ....
        return jax.numpy.array([a,b,c])
    
    @jit
    def opt_fun(x):
        f1(x,params)
        .... 
        return float_value
    

    When I call scipy.optimize.bisect(opt_fun, x0, x1) it runs with no issue, but jaxopt.Bisection(opt_fun, x0, x1).run(None) hangs with ~10% CPU usage and ~55% memory usage on an i9 2018 MacBook Pro with 32 GB of memory.

    I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.
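
    For comparison, a hedged minimal example of how I understand jaxopt.Bisection is meant to be called on a scalar root-finding problem (the cubic here is just a placeholder):

    import jax.numpy as jnp
    from jaxopt import Bisection
    
    def optimality_fun(x, c):
        return x ** 3 - c  # root at c ** (1/3), bracketed by [0, 2] for c = 1.5
    
    bisec = Bisection(optimality_fun=optimality_fun, lower=0.0, upper=2.0)
    sol = bisec.run(None, c=1.5)
    print(sol.params)  # approximately 1.1447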

    opened by jjruby09 7
  • Incompatible shape in solve_normal_cg

    When A.shape = (N, P) for N != P, I run into shape errors when trying to use solve_normal_cg for fitting the normal equations.

    I have a small reproducible example below for N > P, but the error holds for when P > N.

    import jax.numpy as jnp
    import numpy as np
    N = 1000
    P = 3
    prob = np.random.uniform(0.01, 0.5, size=P)
    h2g = 0.1
    X = np.random.binomial(2, p=prob, size=(N, P))
    b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
    y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
    
    import jaxopt as jopt
    jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Input In [11], in <module>
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
        148 if ridge is not None:
        149   _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
    --> 151 Ab = _rmatvec(matvec, b)
        153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
        112 def _rmatvec(matvec, x):
        113   """Computes A^T x, from matvec(x) = A x, where A is square."""
    --> 114   transpose = jax.linear_transpose(matvec, x)
        115   return transpose(x)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
       2208 in_dtypes = map(dtypes.dtype, in_avals)
       2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
    -> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
       2212                                              instantiate=True)
       2213 out_avals, _ = unzip2(out_pvals)
       2214 out_dtypes = map(dtypes.dtype, out_avals)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
        503 with core.new_main(JaxprTrace) as main:
        504   fun = trace_to_subjaxpr(fun, main, instantiate)
    --> 505   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
        506   assert not env
        507   del main, fun, env
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
        163 gen = gen_static_args = out_store = None
        165 try:
    --> 166   ans = self.f(*args, **dict(self.params, **kwargs))
        167 except:
        168   # Some transformations yield from inside context managers, so we have to
        169   # interrupt them before reraising the exception. Otherwise they will only
        170   # get garbage-collected at some later time, running their cleanup tasks only
        171   # after this exception is handled, which can corrupt the global state.
        172   while stack:
    
    Input In [11], in <lambda>(x)
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
       4194   return lax.mul(a, b)
       4195 if _max(a_ndim, b_ndim) <= 2:
    -> 4196   return lax.dot(a, b, precision=precision)
       4198 if b_ndim == 1:
       4199   contract_dims = ((a_ndim - 1,), (0,))
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
        664   return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
        665                      precision=precision, preferred_element_type=preferred_element_type)
        666 else:
    --> 667   raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        668       lhs.shape, rhs.shape))
    
    TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
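
    A hedged workaround sketch (not the library fix): solve the normal equations A^T A x = A^T y directly with jax.scipy.sparse.linalg.cg, which sidesteps the transpose inference on the nonsquare operator:

    import jax
    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg
    
    N, P = 1000, 3
    X = jax.random.normal(jax.random.PRNGKey(0), (N, P))
    y = jax.random.normal(jax.random.PRNGKey(1), (N,))
    
    def normal_matvec(v):
        return X.T @ (X @ v)  # A^T A v without forming A^T A explicitly
    
    x_hat, _ = cg(normal_matvec, X.T @ y)
    print(x_hat.shape)  # (3,)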
    
    opened by quattro 6
  • Initial stepsize not exposed in LBFGS constructor [question/bug?]

    I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.

    I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?

    Thanks in advance, and thanks for a really cool library.

    opened by erdmann 6
  • Infinities and NaNs in quadratic_prog when c=0

    Hi,

    I'm using QuadraticProgramming in the special case of c=0 (all zeros as a vector). AFAIK this is still well-defined, as it's just minimizing l2 norm squared of the primal subject to some equality constraints (I don't have inequalities).

    However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

    The modification just involves setting c=0, so:

    def test_qp_eq_only_c_zero(self):
      Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
      c = jnp.array([0.0, 0.0]) #ONLY CHANGE
      A = jnp.array([[1.0, 1.0]])
      b = jnp.array([1.0])
      qp = QuadraticProgramming(tol=1e-7)
      hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
      sol = qp.run(**hyperparams).params
      self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
      self._check_derivative_A_and_b(qp, hyperparams, A, b)
    

    Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.

    Thanks!

    opened by FerranAlet 6
  • OptaxSolver Error: too many positional arguments

    Hello! I tried to implement the example of implicit differentiation as shown here, but with my own functions. The task is to find the mean of a set of vectors named X via gradient descent.

    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    import jax
    import jax.numpy as jnp
    from jax import grad, random, jit
    from jax import jacobian, hessian, jacfwd, jacrev
    key = random.PRNGKey(0)
    
    import jaxopt
    from jaxopt import implicit_diff
    from jaxopt import linear_solve
    from jaxopt import OptaxSolver, GradientDescent
    import optax
    
    def euclidean_distance(a, b):
        """
        Squared Euclidean distance
        """
        return jnp.inner(a - b, a - b)
    
    def weighted_distance(x, X, w):
        loss = 0
        for i, obj in enumerate(X):
            loss += w[i] * euclidean_distance(obj, x)
        return loss
    
    def identical(Y, Y_grad):
        return Y
    

    Algorithm for finding mean:

    # Mean calculation for manifolds with gradient descent
    @implicit_diff.custom_root(jax.grad(weighted_distance))
    def euclidean_weighted_mean(X_set, weights = None, lr = 0.1, n_iter = 50, plot_loss_flag = False):
        
        if weights is None:
            weights = jnp.full((X_set.shape[0]), 1) / X_set.shape[0]
    
        # init mean with random element from set
        Y = X_set[np.random.randint(0, X_set.shape[0], (1,))][0] 
        
        if plot_loss_flag:
            plot_loss = []
            prev_loss = 0
            plato_iter = 0
            plato_reached = False
        
        for i in range(n_iter):
            
            # calculate loss
            loss = weighted_distance(Y, X_set, weights)
    
            if plot_loss_flag:
                if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                    if not plato_reached:
                        plato_iter = i
                        plato_reached = True
                else:
                    prev_loss = loss
                    plato_reached = False
        
            Y_grad = grad(weighted_distance, argnums= 0)(Y, X_set, weights)
            
            # calculate Riemannian gradient
            riem_grad_Y = Y_grad
            
            # update Y
            Y_step = Y - lr * riem_grad_Y
            
            # project new Y on manifold with retraction
            Y = Y_step
            
            if plot_loss_flag:
              # collect loss for plotting
              plot_loss.append(loss)
        
        if plot_loss_flag:
            print(f"Total loss: {weighted_distance(Y, X_set, weights)} got in {plato_iter} iterations")    
            fig, ax = plt.subplots()
            ax.plot(plot_loss)
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            plt.show()
        return Y
    

    You can launch it like this:

    d = 2
    m = 4
    X = jax.random.uniform(key, (m,d))
    euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)
    

    As you can see, I am calculating the weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimize the distance between the resulting mean and a desired point. In my case, I want the weights to influence the algorithm in such a way that the resulting mean is as close to X[0] as possible:

    def global_task_objective(w, X, target_point, lr, n_iter):
        x = euclidean_weighted_mean(X, w, lr = lr, n_iter = n_iter)
        loss = euclidean_distance(x, target_point)
        return loss, x
    
    target_point = X[0]
    
    w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0]) 
    
    lr = 1e-3
    n_iter = 100
    
    global_task_objective(w_init, X, target_point, lr, n_iter)
    solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
    state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    

    The problem emerges when I call

    w_init, state = solver.update(params=w_init, 
                                 state=state, 
                                 X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    
    The update fails with a "too many positional arguments" error (screenshot attached). Meanwhile, the official example with ridge regression works perfectly. Any suggestions?
    opened by MarioAuditore 0
  • Custom loop pjit example

    A MWE of how jax.experimental.pjit can be used in JAXopt (see also PR #346).

    NOTE: jax.experimental.pjit is not yet supported in Colab. However, this example illustrates how users with access to Google Cloud TPUs may use jax.experimental.pjit in combination with JAXopt solvers.

    pull ready 
    opened by fllinares 2
  • Added a new API allowing to warm start the inverse Hessian approximation in LBFGS

    This fixes #351 .

    @mblondel I couldn't use your suggestion of creating a new type of init, LBFGSInit, because the init_params variable is used for both init_state and update. Therefore I would have had to add case distinctions in the two functions, which seemed unreasonable. Instead, I took the approach I saw in some other iterative solvers, which is to add an extra keyword argument to init_state, update and _value_and_grad_fun.

    I added a test to make sure that this runs, but I am not sure whether we need to add a test to make sure that it improves some cases. I also don't know whether we should test that differentiation is ok.

    opened by zaccharieramzi 5
  • Enable warm-starting the hessian approximation in L-BFGS

    Currently one can only provide an initial estimate of the solution, which enables warm-starting of the iterates. But for quasi-Newton methods it can also be a good idea to provide an initial estimate of the Hessian approximation, typically when solving a similar problem multiple times (see the sketch below for the warm start that is currently possible).
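
    For reference, a hedged sketch of that currently-available warm start, reusing the previous solution as init_params across a sequence of similar problems (the loss is a toy placeholder):

    import jax.numpy as jnp
    from jaxopt import LBFGS
    
    def loss(w, target):
        return jnp.sum((w - target) ** 2)
    
    solver = LBFGS(fun=loss, maxiter=100)
    params = jnp.zeros(3)
    for shift in jnp.linspace(0.0, 1.0, 5):
        # each solve starts from the previous solution instead of from scratch
        params, state = solver.run(params, target=jnp.array([1.0, 2.0, 3.0]) + shift)
        print(float(shift), params)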

    This was for example done in HOAG by @fabianp (see https://github.com/fabianp/hoag/blob/master/hoag/hoag.py#L109).

    I am willing to implement this in the next few weeks.

    As I know it is of interest to them as well, cc-ing @marius311 and @mblondel

    opened by zaccharieramzi 2
  • Batched QP (and other optimization algorithms)

    I'm trying to make OSQP batchable (so I can use it as a layer in neural networks, like OptNet), but I couldn't find any documentation yet about using vmap to solve batched versions of optimization problems.
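
    A hedged sketch of the general batching pattern, shown with GradientDescent on a toy quadratic rather than OSQP, purely to illustrate vmapping a solver's run method over batched problem data:

    import jax
    import jax.numpy as jnp
    from jaxopt import GradientDescent
    
    def objective(x, target):
        return jnp.sum((x - target) ** 2)
    
    solver = GradientDescent(fun=objective, maxiter=100)
    
    def solve_one(target):
        return solver.run(jnp.zeros(2), target=target).params
    
    targets = jnp.arange(10.0).reshape(5, 2)  # a batch of 5 problems
    batched_solutions = jax.vmap(solve_one)(targets)
    print(batched_solutions.shape)  # (5, 2)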

    opened by jn-tang 1
Releases(jaxopt-v0.5.5)
  • jaxopt-v0.5.5(Oct 20, 2022)

    New features

    • Added MAML example by Fabian Pedregosa based on initial code by Paul Vicol and Eric Jiang.
    • Added the possibility to stop LBFGS after a line search failure, by Zaccharie Ramzi.
    • Added gamma to LBFGS state, by Zaccharie Ramzi.
    • Added jaxopt.BFGS, by Mathieu Blondel.
    • Added value_and_grad option to all gradient-based solvers, by Mathieu Blondel.
    • Added Fenchel-Young loss, by Quentin Berthet.
    • Added projection_sparse_simplex, by Tianlin Liu.

    Bug fixes and enhancements

    • Fixed missing args, kwargs in the resnet example, by Louis Béthune.
    • Corrected the implicit diff examples, by Zaccharie Ramzi.
    • Small optimization in l2-regularized semi-dual OT, by Mathieu Blondel.
    • Numerical stability improvements in jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Dtype consistency in LBFGS, by Alex Botev.

    Deprecations

    • jaxopt.QuadraticProgramming is now fully removed. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Alex Botev, Amir Saadat, Fabian Pedregosa, Louis Béthune, Mathieu Blondel, Quentin Berthet, Tianlin Liu, Zaccharie Ramzi.

  • jaxopt-v0.5(Aug 30, 2022)

    New features

    • Added optimal transport related projections: projection_transport, projection_birkhoff, kl_projection_transport, and kl_projection_birkhoff, by Mathieu Blondel (semi-dual formulation) and Tianlin Liu (dual formulation).

    Bug fixes and enhancements

    • Fix LaTeX rendering issue in notebooks, by Amélie Héliou.
    • Avoid gradient recompilations in zoom line search, by Mathieu Blondel.
    • Fix unused Jacobian issue in jaxopt.ScipyRootFinding, by Louis Béthune.
    • Use zoom line search by default in jaxopt.LBFGS and jaxopt.NonlinearCG, by Mathieu Blondel.
    • Pass tolerance argument to jaxopt.ScipyMinimize, by pipme.
    • Handle has_aux in jaxopt.LevenbergMarquardt, by Keunhong Park.
    • Add maxiter keyword argument in jaxopt.ScipyMinimize, by Fabian Pedregosa.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Keunhong Park, Fabian Pedregosa, pipme.

  • jaxopt-v0.4.3(Jun 28, 2022)

    New features

    • Added zoom line search in jaxopt.LBFGS, by Mathieu Blondel. It can be enabled with the linesearch="zoom" option.
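
    A minimal illustration of this option (the quadratic objective here is a placeholder):

    import jax.numpy as jnp
    from jaxopt import LBFGS
    
    solver = LBFGS(fun=lambda x: jnp.sum(x ** 2), linesearch="zoom")
    print(solver.run(jnp.ones(3)).params)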

    Bug fixes and enhancements

    • Added support for quadratic polynomial fun in jaxopt.BoxOSQP and jaxopt.OSQP, by Louis Béthune.
    • Added a notebook for the dataset distillation example, by Amélie Héliou.
    • Fixed wrong links and deprecation warnings in notebooks, by Fabian Pedregosa.
    • Changed losses to avoid roundoff, by Jack Valmadre.
    • Fixed init_params bug in multiclass_svm example, by Louis Béthune.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Fabian Pedregosa, Jack Valmadre.

  • jaxopt-v0.4.2(Jun 10, 2022)

  • jaxopt-v0.4.1(Jun 10, 2022)

    Bug fixes and enhancements

    • Improvements in jaxopt.LBFGS: fixed bug when using use_gamma=True, added stepsize option, strengthened tests, by Mathieu Blondel.
    • Fixed link in resnet notebook, by Fabian Pedregosa.

    Contributors

    Fabian Pedregosa, Mathieu Blondel.

  • jaxopt-v0.4(May 24, 2022)

    New features

    • Added solver jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Added solver jaxopt.BoxCDQP, by Mathieu Blondel.
    • Added projection_hypercube, by Mathieu Blondel.

    Bug fixes and enhancements

    • Fixed solve_normal_cg when the linear operator is “nonsquare” (does not map to a space of same dimension), by Mathieu Blondel.
    • Fixed edge case in jaxopt.Bisection, by Mathieu Blondel.
    • Replaced deprecated tree_multimap with tree_map, by Fan Yang.
    • Added support for leaf cond pytrees in tree_where, by Felipe Llinares.
    • Added Python 3.10 support officially, by Jeppe Klitgaard.
    • In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
    • Converted the “Resnet” and “Adversarial Training” examples to notebooks, by Fabian Pedregosa.

    Contributors

    Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3.1(Feb 28, 2022)

    New features

    • Pjit-based example of data parallel training using Flax, by Felipe Llinares.

    Bug fixes and enhancements

    • Support for GPU and a state-of-the-art adversarial training algorithm (PGD) in the robust_training.py example, by Fabian Pedregosa.
    • Update line search in LBFGS to use jit and unroll from LBFGS, by Ian Williamson.
    • Support dynamic maximum iteration count in iterative solvers, by Roy Frostig.
    • Fix tree_where for singleton pytrees, by Louis Béthune.
    • Remove QuadraticProg in projections and set init_params=None by default in QP solvers, by Louis Béthune.
    • Add missing 'value' attribute in LbfgsState, by Mathieu Blondel.

    Contributors

    Felipe Llinares, Fabian Pedregosa, Ian Williamson, Louis Béthune, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3(Jan 31, 2022)

    New features

    • jaxopt.LBFGS
    • jaxopt.BacktrackingLineSearch
    • jaxopt.GaussNewton
    • jaxopt.NonlinearCG

    Bug fixes and enhancements

    • Support implicit AD in higher-order differentiation.

    Contributors

    Amir Saadat, Fabian Pedregosa, Geoffrey Négiar, Hyunsung Lee, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.2(Dec 18, 2021)

    New features

    • Quadratic programming solvers jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP
    • Iterative refinement

    New examples

    • Resnet example with Flax and JAXopt.

    Bug fixes and enhancements

    • Prevent recompilation of loops in solver.run if executing without jit.
    • Prevent recomputation of the gradient in OptaxSolver.
    • Make solver.update jittable and ensure output states are consistent.
    • Allow Callable for the stepsize argument in jaxopt.ProximalGradient, jaxopt.ProjectedGradient and jaxopt.GradientDescent.
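
    A hedged sketch of the Callable stepsize, assuming the callable receives the iteration number and returns a step size:

    import jax.numpy as jnp
    from jaxopt import GradientDescent
    
    solver = GradientDescent(fun=lambda x: jnp.sum(x ** 2),
                             stepsize=lambda i: 0.5 / (i + 1),  # decaying schedule
                             maxiter=50)
    print(solver.run(jnp.ones(3)).params)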

    Deprecated features

    • jaxopt.QuadraticProgramming is deprecated and will be removed in v0.3. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Geoffrey Negiar, Louis Bethune, Mathieu Blondel, Vikas Sindhwani.

  • jaxopt-v0.1.1(Oct 19, 2021)

    New features

    • Added solver jaxopt.ArmijoSGD
    • Added example Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
    • Added example Comparison of different SGD algorithms.

    Bug fixes

    • Allow non-jittable proximity operators in jaxopt.ProximalGradient
    • Raise an exception if a quadratic program is infeasible or unbounded

    Contributors

    Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

  • jaxopt-v0.1(Oct 14, 2021)

    Classes

    • jaxopt.AndersonAcceleration
    • jaxopt.AndersonWrapper
    • jaxopt.Bisection
    • jaxopt.BlockCoordinateDescent
    • jaxopt.FixedPointIteration
    • jaxopt.GradientDescent
    • jaxopt.MirrorDescent
    • jaxopt.OptaxSolver
    • jaxopt.PolyakSGD
    • jaxopt.ProjectedGradient
    • jaxopt.ProximalGradient
    • jaxopt.QuadraticProgramming
    • jaxopt.ScipyBoundedLeastSquares
    • jaxopt.ScipyBoundedMinimize
    • jaxopt.ScipyLeastSquares
    • jaxopt.ScipyMinimize
    • jaxopt.ScipyRootFinding
    • Implicit differentiation

    Examples

    • Binary kernel SVM with intercept.
    • Image classification example with Flax and JAXopt.
    • Image classification example with Haiku and JAXopt.
    • VAE example with Haiku and JAXopt.
    • Implicit differentiation of lasso.
    • Multiclass linear SVM (without intercept).
    • Non-negative matrix factorization (NMF) using alternating minimization.
    • Dataset distillation.
    • Implicit differentiation of ridge regression.
    • Robust training.
    • Anderson acceleration of gradient descent.
    • Anderson acceleration of block coordinate descent.
    • Anderson acceleration in application to Picard–Lindelöf theorem.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Robert Gower, Louis Bethune, Marco Cuturi, Mathieu Blondel, Peter Hawkins, Quentin Berthet, Roy Frostig, Ta-Chu Kao
