Hardware accelerated, batchable and differentiable optimizers in JAX.

Overview

JAXopt

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Installation

JAXopt can be installed with pip directly from GitHub, with the following command:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from sources with the following command:

$ python setup.py install

References

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Comments
  • Levenberg-Marquardt running exceptionally slow unless verbose is enabled

    The Levenberg-Marquardt optimizer runs exceptionally slowly (~30 seconds for 30 iterations) until I set verbose=True (~1 second for 30 iterations). Any idea what may be going on? Enabling JIT seems to have no impact. I was hoping to use this for a real-time system, but even at 1 second things are way too slow.

    opened by pablovela5620 23
  • implementation of Fletcher-Reeves Algorithm

    • Polak-Ribiere method. To my knowledge, conjugate gradient variants have been quite successful on general unconstrained optimization.

    This PR depends on the line search of PR #128.

    • Beta division is required to guarantee the strong Wolfe condition, but (I don't know why) it raises an error.
    pull ready 
    opened by ita9naiwa 17
  • vmap support in QPs

    Hi, I am experiencing a problem with projection_polyhedron:

    import numpy as np
    import matplotlib.pyplot as plt
    
    import jax
    import jax.numpy as jnp
    
    import jaxopt
    from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron
    
    def myproj3(x):
        A = jnp.array([[1.0, 1.0]])
        b = jnp.array([1.0])
        G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
        h = jnp.array([0.0, 0.0])    
        x = projection_polyhedron(x,hyperparams = (A, b, G, h))
        return x
    
    rng_key = jax.random.PRNGKey(42)
    x = jax.random.uniform(rng_key, (5000,2), minval=-3,maxval=3)
    p1_x = jax.vmap(myproj3, in_axes=0)(x)  # myproj3 takes a single argument
    fig, ax = plt.subplots(figsize=(5,5))
    ax.scatter(x[:,0],x[:,1],s=0.5)
    ax.scatter(p1_x[:,0],p1_x[:,1],s=0.5,c='g')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.show()
    

    First, I had to install cvxpy (!pip install cvxpy). Then, I got this error:

    TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
      with val = DeviceArray([[-2.37103211,  2.33759997],
                              [ 2.76953806, -2.37750394],
                              [-0.87246632,  0.73224625],
                              ...,
                              [ 2.29799773,  2.81894884],
                              [ 2.4022714 ,  0.80693103],
                              [-0.41563116,  2.83898531]], dtype=float64)
           batch_dim = 0
    

    Does anyone have a hint? Thanks

    enhancement 
    opened by jecampagne 12
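The constraint set in the snippet above (x >= 0 and x1 + x2 = 1) happens to be the unit simplex. One possible workaround, assuming a cvxpy-free projection is acceptable, is a sort-based simplex projection written in pure JAX, which traces cleanly under jax.vmap. The name project_simplex is hypothetical, and this is a sketch rather than the library's implementation:

```python
import jax
import jax.numpy as jnp

def project_simplex(x):
    """Euclidean projection of a 1-D vector onto {p : p >= 0, sum(p) = 1}."""
    n = x.shape[0]
    u = jnp.sort(x)[::-1]              # entries in decreasing order
    cssv = jnp.cumsum(u) - 1.0         # cumulative sums minus the target total
    ind = jnp.arange(1, n + 1)
    cond = u - cssv / ind > 0          # True for entries that stay positive
    rho = jnp.count_nonzero(cond)      # size of the support
    theta = cssv[rho - 1] / rho        # shift that makes the result sum to 1
    return jnp.maximum(x - theta, 0.0)

x = jax.random.uniform(jax.random.PRNGKey(42), (5000, 2), minval=-3, maxval=3)
p = jax.vmap(project_simplex)(x)       # batched over the leading axis
```

Because every operation here is a jax.numpy primitive, no conversion of a tracer to a NumPy array is attempted, which is what triggered the TracerArrayConversionError in the report.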
  • KKT conditions when the primal solution is a pytree

    Hi, Congrats on the great tool! Inspired by the QuadraticProgramming example I built a code that differentiates through KKT conditions. My code works whenever the primal solution variable is a jnp array, but not when it's a generic pytree, in which case I get the following error:

    TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

    where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

    I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).

    To make it easier to reproduce I modified the quadratic_prog.py file by making the model return a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified the obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't, this test line raises an assert for an array that should be all zeros and instead is: ([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

    Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.

    Notice also that the test on the line just above (checking that the primal solution is correct) still holds provided we check sol[0][0] instead of sol[0] (since sol[0] is now a 1-element list).

    Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?

    Thanks!

    opened by FerranAlet 11
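The symptom described above (two pieces that should have summed to zero appearing side by side, with the primal structure occurring twice) is consistent with Python's + concatenating lists instead of adding them leaf by leaf. Whether this is what happens inside jaxopt is an assumption, not something confirmed by the report. The sketch below only illustrates the difference, using the numbers quoted in the issue:

```python
import jax
import jax.numpy as jnp

# Gradients stored as one-element lists, as in the modified example.
grad_obj = [jnp.array([0.44, -1.32])]
grad_eq = [jnp.array([-0.44, 1.32])]

# Python's + on lists concatenates: the result has two leaves.
concatenated = grad_obj + grad_eq

# A leaf-wise sum keeps the pytree structure and adds matching leaves.
summed = jax.tree_util.tree_map(jnp.add, grad_obj, grad_eq)
```

Here concatenated has the primal structure twice, while summed is a one-element list whose array is numerically zero.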
  • Hot fix: corrected condition in lbfgs

    The feature I had introduced in https://github.com/google/jaxopt/pull/323 was failing when the run function was jitted and was a no-op when not because of the following reason:

     ~True == -2  # this is True
    

    Therefore when jitted it was complaining about different types in a condition function, and when not jitted it was equivalent to always being False.

    EDIT

    Actually I am still running into an error when jitted, so will continue to investigate.

    The gist of the error is Abstract tracer value encountered where concrete value is expected, basically doing (not self.stop_if_linesearch_fails | ~state.failed_linesearch) is not allowed because one is a bool and the other is an abstract value.

    pull ready 
    opened by zaccharieramzi 7
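The pitfall described in this PR can be reproduced without jaxopt. The sketch below shows why ~ misbehaves on a Python bool and one way to write the condition so that it also works under jit. The helper keep_going is hypothetical and is not the code that was merged:

```python
import jax
import jax.numpy as jnp

# On Python bools, ~ is bitwise inversion of the integer value, not logical negation.
print(~True, ~False)   # -2 -1, both of which are truthy
print(not True)        # False

def keep_going(stop_if_linesearch_fails, failed_linesearch):
    # Hypothetical helper: logical ops from jax.numpy accept both Python bools
    # and traced boolean arrays, and always return a boolean array.
    return jnp.logical_or(jnp.logical_not(stop_if_linesearch_fails),
                          jnp.logical_not(failed_linesearch))

keep_going_jit = jax.jit(keep_going)
```

Using jnp.logical_not and jnp.logical_or avoids both problems mentioned above: the result has a boolean dtype, and no concrete value is required from a traced argument.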
  • Issue with gradients wrt optimality fn parameters through root finding vjp

    First of all, thanks a lot for this library! Really useful tools! I'm interested in getting at least 2nd order gradients through root finding, and I'm finding an odd behavior that I wanted to report.

    Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:

    def inv_f(x, aux):
      bisec = Bisection(optimality_fun=F, lower=0.0, upper=1., 
                        check_bracket=False, unroll=True)
      return bisec.run(aux=aux).params
    
    # Here I extract the value part of the vjp, but the grad part also gives wrong results
    test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0] 
    
    jax.grad(test_fn)(1.) # Returns 0 instead of the expected gradients
    

    Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (but not as obvious).

    I made a small demo notebook that reproduces this issue here.

    As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.

    Reading the source code of jaxopt, it is not immediately obvious to me why this doesn't work... Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).

    bug 
    opened by EiffL 7
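For reference, hand-written implicit gradients of the kind mentioned in this report can be sketched in pure JAX with jax.custom_jvp and the implicit function theorem. The function F below is a hypothetical stand-in (its root in [0, 1] is sqrt(theta)), and the bisection loop is a minimal one that assumes F is increasing in x; it is not jaxopt's Bisection:

```python
import jax
import jax.numpy as jnp

def F(x, theta):
    # Hypothetical optimality function: its root in [0, 1] is sqrt(theta).
    return x ** 2 - theta

@jax.custom_jvp
def root(theta):
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_right = F(mid, theta) < 0   # valid because F is increasing in x here
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)
    lo, hi = jax.lax.fori_loop(0, 50, body, (jnp.zeros(()), jnp.ones(())))
    return 0.5 * (lo + hi)

@root.defjvp
def root_jvp(primals, tangents):
    (theta,), (dtheta,) = primals, tangents
    x = root(theta)
    dF_dx = jax.grad(F, argnums=0)(x, theta)
    dF_dtheta = jax.grad(F, argnums=1)(x, theta)
    # Implicit function theorem: dx/dtheta = -(dF/dtheta) / (dF/dx).
    return x, -(dF_dtheta / dF_dx) * dtheta

first = jax.grad(root)(0.25)             # analytically 1 / (2 * sqrt(theta)) = 1
second = jax.grad(jax.grad(root))(0.25)  # analytically -theta**(-1.5) / 4 = -2
```

Because the JVP rule is itself written with differentiable JAX operations and calls root again for the primal value, higher-order derivatives follow from the same rule without differentiating through the bisection loop.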
  • misc improvements to robust training example

    main changes:

    • Fixes #134 by normalizing in-place.
    • Plot convergence curves for both clean and adversarial accuracy.
    • Replace the fast-sign-gradient method by the much more powerful PGD method.
    • Be able to select different datasets.
    • Homogenize the API with respect to the other examples. For example, this now uses the same load_dataset, CNN, loss_fun and accuracy as flax_image_classif.py. Most of the command line flags have also been homogenized.
    pull ready 
    opened by fabianp 7
  • Bisection hanging

    I am trying to use jaxopt.Bisection to replace scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.

    The basic structure includes 2 functions that are both jitted (so I assume it should be able to compile ok):

    @jit
    def f1(parameters):
        ....
        return jax.numpy.array([a,b,c])
    
    @jit
    def opt_fun(x):
        f1(x,params)
        .... 
        return float_value
    

    When I call scipy.optimize.bisect(opt_fun,x0,x1) it runs with no issue, but jaxopt.Bisection(opt_fun,x0,x1).run(None) hangs with ~10% CPU usage and 55% memory usage on a 2018 i9 MacBook Pro with 32GB of memory.

    I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.

    opened by jjruby09 7
  • Incompatible shape in solve_normal_cg

    When A.shape = (N, P) with N != P, I run into shape errors when trying to use solve_normal_cg to solve the normal equations.

    I have a small reproducible example below for N > P, but the error also occurs when P > N.

    import jax.numpy as jnp
    import numpy as np
    N = 1000
    P = 3
    prob = np.random.uniform(0.01, 0.5, size=P)
    h2g = 0.1
    X = np.random.binomial(2, p=prob, size=(N, P))
    b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
    y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
    
    import jaxopt as jopt
    jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Input In [11], in <module>
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
        148 if ridge is not None:
        149   _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
    --> 151 Ab = _rmatvec(matvec, b)
        153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
        112 def _rmatvec(matvec, x):
        113   """Computes A^T x, from matvec(x) = A x, where A is square."""
    --> 114   transpose = jax.linear_transpose(matvec, x)
        115   return transpose(x)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
       2208 in_dtypes = map(dtypes.dtype, in_avals)
       2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
    -> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
       2212                                              instantiate=True)
       2213 out_avals, _ = unzip2(out_pvals)
       2214 out_dtypes = map(dtypes.dtype, out_avals)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
        503 with core.new_main(JaxprTrace) as main:
        504   fun = trace_to_subjaxpr(fun, main, instantiate)
    --> 505   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
        506   assert not env
        507   del main, fun, env
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
        163 gen = gen_static_args = out_store = None
        165 try:
    --> 166   ans = self.f(*args, **dict(self.params, **kwargs))
        167 except:
        168   # Some transformations yield from inside context managers, so we have to
        169   # interrupt them before reraising the exception. Otherwise they will only
        170   # get garbage-collected at some later time, running their cleanup tasks only
        171   # after this exception is handled, which can corrupt the global state.
        172   while stack:
    
    Input In [11], in <lambda>(x)
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
       4194   return lax.mul(a, b)
       4195 if _max(a_ndim, b_ndim) <= 2:
    -> 4196   return lax.dot(a, b, precision=precision)
       4198 if b_ndim == 1:
       4199   contract_dims = ((a_ndim - 1,), (0,))
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
        664   return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
        665                      precision=precision, preferred_element_type=preferred_element_type)
        666 else:
    --> 667   raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        668       lhs.shape, rhs.shape))
    
    TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
    
    opened by quattro 6
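The traceback shows the transpose being traced with an input shaped like b (length N) even though matvec expects an input of length P. A minimal sketch of the idea, assuming the caller can supply an example input of the right shape, is given below. The function name solve_normal_cg_sketch and its example_x parameter are hypothetical, and this is not the fix that was applied in jaxopt:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def solve_normal_cg_sketch(matvec, b, example_x):
    """Solve A^T A x = A^T b, where matvec(x) = A x and A may be nonsquare."""
    # The transpose must be traced with an input shaped like x, not like b.
    rmatvec = jax.linear_transpose(matvec, example_x)
    Atb = rmatvec(b)[0]
    normal_matvec = lambda x: rmatvec(matvec(x))[0]
    x, _ = cg(normal_matvec, Atb)
    return x

A = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]])  # N=4, P=2
b = jnp.array([1.0, 2.0, 3.0, 4.0])
x = solve_normal_cg_sketch(lambda v: A @ v, b, jnp.zeros(2))
```

For this small system A^T A = 3 I and A^T b = [8, 1], so the least-squares solution is [8/3, 1/3].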
  • Initial stepsize not exposed in LBFGS constructor [question/bug?]

    I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.

    I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?

    Thanks in advance, and thanks for a really cool library.

    opened by erdmann 6
  • Infinities and NaNs in quadratic_prog when c=0

    Hi,

    I'm using QuadraticProgramming in the special case of c=0 (all zeros as a vector). AFAIK this is still well-defined, as it's just minimizing l2 norm squared of the primal subject to some equality constraints (I don't have inequalities).

    However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

    The modification just involves setting c=0, so:

    def test_qp_eq_only_c_zero(self):
      Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
      c = jnp.array([0.0, 0.0]) #ONLY CHANGE
      A = jnp.array([[1.0, 1.0]])
      b = jnp.array([1.0])
      qp = QuadraticProgramming(tol=1e-7)
      hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
      sol = qp.run(**hyperparams).params
      self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
      self._check_derivative_A_and_b(qp, hyperparams, A, b)
    

    Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.

    Thanks!

    opened by FerranAlet 6
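As a sanity check that the c=0 problem is well-posed, the equality-constrained case can be solved directly from its KKT system with a dense linear solve. This is a sketch under the assumption that the objective is 0.5 x^T Q x + c^T x subject to A x = b; the helper solve_eq_qp is hypothetical and is not how QuadraticProgramming is implemented:

```python
import jax.numpy as jnp

def solve_eq_qp(Q, c, A, b):
    """Minimize 0.5 x^T Q x + c^T x subject to A x = b via one KKT solve."""
    n, m = Q.shape[0], A.shape[0]
    # KKT system: [[Q, A^T], [A, 0]] @ [x, lam] = [-c, b]
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])
    sol = jnp.linalg.solve(kkt, jnp.concatenate([-c, b]))
    return sol[:n], sol[n:]

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([0.0, 0.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
x, lam = solve_eq_qp(Q, c, A, b)
```

With these values the stationarity and feasibility conditions give x = [0.25, 0.75] and a multiplier of -1.75, so a finite solution exists even though c is all zeros.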
  • OptaxSolver Error: too many positional arguments

    Hello! I tried to implement the example of implicit differentiation as shown here but with my own functions. The task is to find the mean of a set of vectors named X via gradient descent.

    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    import jax
    import jax.numpy as jnp
    from jax import grad, random, jit
    from jax import jacobian, hessian, jacfwd, jacrev
    key = random.PRNGKey(0)
    
    import jaxopt
    from jaxopt import implicit_diff
    from jaxopt import linear_solve
    from jaxopt import OptaxSolver, GradientDescent
    import optax
    
    def euclidean_distance(a, b):
        """
        Squared Euclidean distance
        """
        return jnp.inner(a - b, a - b)
    
    def weighted_distance(x, X, w):
        loss = 0
        for i, obj in enumerate(X):
            loss += w[i] * euclidean_distance(obj, x)
        return loss
    
    def identical(Y, Y_grad):
        return Y
    

    Algorithm for finding mean:

    # Mean calculation for manifolds with gradient descent
    @implicit_diff.custom_root(jax.grad(weighted_distance))
    def euclidean_weighted_mean(X_set, weights = None, lr = 0.1, n_iter = 50, plot_loss_flag = False):
        
        if weights is None:
            weights = jnp.full((X_set.shape[0]), 1) / X_set.shape[0]
    
        # init mean with random element from set
        Y = X_set[np.random.randint(0, X_set.shape[0], (1,))][0] 
        
        if plot_loss_flag:
            plot_loss = []
            prev_loss = 0
            plato_iter = 0
            plato_reached = False
        
        for i in range(n_iter):
            
            # calculate loss
            loss = weighted_distance(Y, X_set, weights)
    
            if plot_loss_flag:
                if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                    if not plato_reached:
                        plato_iter = i
                        plato_reached = True
                else:
                    prev_loss = loss
                    plato_reached = False
        
            Y_grad = grad(weighted_distance, argnums= 0)(Y, X_set, weights)
            
            # calculate Riemannian gradient
            riem_grad_Y = Y_grad
            
            # update Y
            Y_step = Y - lr * riem_grad_Y
            
            # project new Y on manifold with retraction
            Y = Y_step
            
            if plot_loss_flag:
              # collect loss for plotting
              plot_loss.append(loss)
        
        if plot_loss_flag:
            print(f"Total loss: {weighted_distance(Y, X_set, weights)} got in {plato_iter} iterations")    
            fig, ax = plt.subplots()
            ax.plot(plot_loss)
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            plt.show()
        return Y
    

    You can launch it like this:

    d = 2
    m = 4
    X = jax.random.uniform(key, (m,d))
    euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)
    

    As you can see, I am calculating the weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimise the distance between the resulting mean and a desired point. In my case, I want the weights to influence the algorithm so that the resulting mean is as close to X[0] as possible:

    def global_task_objective(w, X, target_point, lr, n_iter):
        x = euclidean_weighted_mean(X, w, lr = lr, n_iter = n_iter)
        loss = euclidean_distance(x, target_point)
        return loss, x
    
    target_point = X[0]
    
    w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0]) 
    
    lr = 1e-3
    n_iter = 100
    
    global_task_objective(w_init, X, target_point, lr, n_iter)
    solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
    state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    

    The problem emerges when I call

    w_init, state = solver.update(params=w_init, 
                                 state=state, 
                                 X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    
    Meanwhile, the official example with ridge regression works perfectly. Any suggestions?
    opened by MarioAuditore 0
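For the Euclidean case in this report, the inner problem has a closed-form solution, which gives a useful reference when debugging the implicit-differentiation setup. The sketch below vectorizes weighted_distance (replacing the Python loop) and checks that the weighted mean is a stationary point. The helper weighted_mean is hypothetical and assumes the weights do not sum to zero:

```python
import jax
import jax.numpy as jnp

def weighted_distance(x, X, w):
    # Sum over points of w_i * ||X_i - x||^2, without a Python loop.
    return jnp.sum(w * jnp.sum((X - x) ** 2, axis=1))

def weighted_mean(X, w):
    # Closed-form minimizer of weighted_distance in x (assumes sum(w) != 0).
    return (w @ X) / jnp.sum(w)

X = jax.random.uniform(jax.random.PRNGKey(0), (4, 2))
w = jnp.array([0.1, 0.2, 0.3, 0.4])

x_star = weighted_mean(X, w)
grad_at_x_star = jax.grad(weighted_distance)(x_star, X, w)  # should vanish
dx_dw = jax.jacobian(weighted_mean, argnums=1)(X, w)        # shape (2, 4)
```

The Jacobian dx_dw is what differentiating the gradient-descent solver with respect to the weights should reproduce once the solver has converged.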
  • Custom loop pjit example

    An MWE of how jax.experimental.pjit can be used in JAXopt (see also PR #346).

    NOTE: jax.experimental.pjit is not yet supported in Colab. However, this example illustrates how users with access to Google Cloud TPUs may use jax.experimental.pjit in combination with JAXopt solvers.

    pull ready 
    opened by fllinares 2
  • Added a new API allowing to warm start the inverse Hessian approximation in LBFGS

    This fixes #351 .

    @mblondel I couldn't use your suggestion of creating a new type of init LBFGSInit because the init_params variable is used for both init_state and update. Therefore I would have had to add case distinctions in the 2 functions which seemed unreasonable. Rather I took the approach I saw in some other iterative solvers which was to add an extra keyword argument to init_state, update and _value_and_grad_fun.

    I added a test to make sure that this runs, but I am not sure whether we need to add a test to make sure that it improves some cases. I also don't know whether we should test that differentiation is ok.

    opened by zaccharieramzi 5
  • Enable warm-starting the hessian approximation in L-BFGS

    Currently one can only provide an initial estimate of the solution, enabling warm start of the iterates. But for quasi-Newton methods, it can also be a good idea to provide an initial estimate of the Hessian approximation, typically when solving a similar problem multiple times.

    This was for example done in HOAG by @fabianp (see https://github.com/fabianp/hoag/blob/master/hoag/hoag.py#L109).

    I am willing to implement this in the next few weeks.

    As I know it is of interest to them as well, cc-ing @marius311 and @mblondel

    opened by zaccharieramzi 2
  • Batched QP (and other optimization algorithm)

    I'm trying to make OSQP batchable (so I can make it a layer in neural networks, like OptNet), but I couldn't find any documentation yet about using vmap to solve batched versions of optimization problems.

    opened by jn-tang 1
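The general batching pattern with jax.vmap can be illustrated on a much simpler solver than OSQP. The sketch below batches an equality-constrained QP, solved through its KKT system, over per-example c and b while sharing Q and A. The helper solve_eq_qp is hypothetical; whether the same in_axes pattern applies unchanged to jaxopt.OSQP is an assumption not verified here:

```python
import jax
import jax.numpy as jnp

def solve_eq_qp(Q, c, A, b):
    """Minimize 0.5 x^T Q x + c^T x subject to A x = b; returns the primal x."""
    n, m = Q.shape[0], A.shape[0]
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])
    sol = jnp.linalg.solve(kkt, jnp.concatenate([-c, b]))
    return sol[:n]

Q = jnp.array([[4.0, 1.0], [1.0, 2.0]])
A = jnp.array([[1.0, 1.0]])
c_batch = jnp.zeros((3, 2))                    # one c per problem
b_batch = jnp.array([[1.0], [2.0], [3.0]])     # one b per problem

# None means "shared across the batch", 0 means "batched along axis 0".
batched_solve = jax.vmap(solve_eq_qp, in_axes=(None, 0, None, 0))
x_batch = batched_solve(Q, c_batch, A, b_batch)   # shape (3, 2)
```

With c = 0 the solution scales linearly with b, so the three rows are 1, 2 and 3 times [0.25, 0.75].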
Releases(jaxopt-v0.5.5)
  • jaxopt-v0.5.5(Oct 20, 2022)

    New features

    • Added MAML example by Fabian Pedregosa based on initial code by Paul Vicol and Eric Jiang.
    • Added the possibility to stop LBFGS after a line search failure, by Zaccharie Ramzi.
    • Added gamma to LBFGS state, by Zaccharie Ramzi.
    • Added jaxopt.BFGS, by Mathieu Blondel.
    • Added value_and_grad option to all gradient-based solvers, by Mathieu Blondel.
    • Added Fenchel-Young loss, by Quentin Berthet.
    • Added projection_sparse_simplex, by Tianlin Liu.

    Bug fixes and enhancements

    • Fixed missing args,kwargs in resnet example, by Louis Béthune.
    • Corrected the implicit diff examples, by Zaccharie Ramzi.
    • Small optimization in l2-regularized semi-dual OT, by Mathieu Blondel.
    • Numerical stability improvements in jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Dtype consistency in LBFGS, by Alex Botev.

    Deprecations

    • jaxopt.QuadraticProgramming is now fully removed. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Alex Botev, Amir Saadat, Fabian Pedregosa, Louis Béthune, Mathieu Blondel, Quentin Berthet, Tianlin Liu, Zaccharie Ramzi.

  • jaxopt-v0.5(Aug 30, 2022)

    New features

    • Added optimal transport related projections: projection_transport, projection_birkhoff, kl_projection_transport, and kl_projection_birkhoff, by Mathieu Blondel (semi-dual formulation) and Tianlin Liu (dual formulation).

    Bug fixes and enhancements

    • Fix LaTeX rendering issue in notebooks, by Amélie Héliou.
    • Avoid gradient recompilations in zoom line search, by Mathieu Blondel.
    • Fix unused Jacobian issue in jaxopt.ScipyRootFinding, by Louis Béthune.
    • Use zoom line search by default in jaxopt.LBFGS and jaxopt.NonlinearCG, by Mathieu Blondel.
    • Pass tolerance argument to jaxopt.ScipyMinimize, by pipme.
    • Handle has_aux in jaxopt.LevenbergMarquardt, by Keunhong Park.
    • Add maxiter keyword argument in jaxopt.ScipyMinimize, by Fabian Pedregosa.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Keunhong Park, Fabian Pedregosa, pipme.

  • jaxopt-v0.4.3(Jun 28, 2022)

    New features

    • Added zoom line search in jaxopt.LBFGS, by Mathieu Blondel. It can be enabled with the linesearch="zoom" option.

    Bug fixes and enhancements

    • Added support for quadratic polynomial fun in jaxopt.BoxOSQP and jaxopt.OSQP, by Louis Béthune.
    • Added a notebook for the dataset distillation example, by Amélie Héliou.
    • Fixed wrong links and deprecation warnings in notebooks, by Fabian Pedregosa.
    • Changed losses to avoid roundoff, by Jack Valmadre.
    • Fixed init_params bug in multiclass_svm example, by Louis Béthune.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Fabian Pedregosa, Jack Valmadre.

  • jaxopt-v0.4.2(Jun 10, 2022)

  • jaxopt-v0.4.1(Jun 10, 2022)

    Bug fixes and enhancements

    • Improvements in jaxopt.LBFGS: fixed bug when using use_gamma=True, added stepsize option, strengthened tests, by Mathieu Blondel.
    • Fixed link in resnet notebook, by Fabian Pedregosa.

    Contributors

    Fabian Pedregosa, Mathieu Blondel.

  • jaxopt-v0.4(May 24, 2022)

    New features

    • Added solver jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Added solver jaxopt.BoxCDQP, by Mathieu Blondel.
    • Added projection_hypercube, by Mathieu Blondel.

    Bug fixes and enhancements

    • Fixed solve_normal_cg when the linear operator is “nonsquare” (does not map to a space of same dimension), by Mathieu Blondel.
    • Fixed edge case in jaxopt.Bisection, by Mathieu Blondel.
    • Replaced deprecated tree_multimap with tree_map, by Fan Yang.
    • Added support for leaf cond pytrees in tree_where, by Felipe Llinares.
    • Added Python 3.10 support officially, by Jeppe Klitgaard.
    • In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
    • Converted the “Resnet” and “Adversarial Training” examples to notebooks, by Fabian Pedregosa.

    Contributors

    Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3.1(Feb 28, 2022)

    New features

    • Pjit-based example of data parallel training using Flax, by Felipe Llinares.

    Bug fixes and enhancements

    • Support for GPU and state of the art adversarial training algorithm (PGD) on the robust_training.py example, by Fabian Pedregosa.
    • Update line search in LBFGS to use jit and unroll from LBFGS, by Ian Williamson.
    • Support dynamic maximum iteration count in iterative solvers, by Roy Frostig.
    • Fix tree_where for singleton pytrees, by Louis Béthune.
    • Remove QuadraticProg in projections and set init_params=None by default in QP solvers, by Louis Béthune.
    • Add missing 'value' attribute in LbfgsState, by Mathieu Blondel.

    Contributors

    Felipe Llinares, Fabian Pedregosa, Ian Williamson, Louis Béthune, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3(Jan 31, 2022)

    New features

    • jaxopt.LBFGS
    • jaxopt.BacktrackingLineSearch
    • jaxopt.GaussNewton
    • jaxopt.NonlinearCG

    Bug fixes and enhancements

    • Support implicit AD in higher-order differentiation.

    Contributors

    Amir Saadat, Fabian Pedregosa, Geoffrey Négiar, Hyunsung Lee, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.2(Dec 18, 2021)

    New features

    • Quadratic programming solvers jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP
    • Iterative refinement

    New examples

    • Resnet example with Flax and JAXopt.

    Bug fixes and enhancements

    • Prevent recompilation of loops in solver.run if executing without jit.
    • Prevents recomputation of gradient in OptaxSolver.
    • Make solver.update jittable and ensure output states are consistent.
    • Allow Callable for the stepsize argument in jaxopt.ProximalGradient, jaxopt.ProjectedGradient and jaxopt.GradientDescent.

    Deprecated features

    • jaxopt.QuadraticProgramming is deprecated and will be removed in v0.3. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Geoffrey Negiar, Louis Bethune, Mathieu Blondel, Vikas Sindhwani.

  • jaxopt-v0.1.1(Oct 19, 2021)

    New features

    • Added solver jaxopt.ArmijoSGD
    • Added example Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
    • Added example Comparison of different SGD algorithms.

    Bug fixes

    • Allow non-jittable proximity operators in jaxopt.ProximalGradient
    • Raise an exception if a quadratic program is infeasible or unbounded

    Contributors

    Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

  • jaxopt-v0.1(Oct 14, 2021)

    Classes

    • jaxopt.AndersonAcceleration
    • jaxopt.AndersonWrapper
    • jaxopt.Bisection
    • jaxopt.BlockCoordinateDescent
    • jaxopt.FixedPointIteration
    • jaxopt.GradientDescent
    • jaxopt.MirrorDescent
    • jaxopt.OptaxSolver
    • jaxopt.PolyakSGD
    • jaxopt.ProjectedGradient
    • jaxopt.ProximalGradient
    • jaxopt.QuadraticProgramming
    • jaxopt.ScipyBoundedLeastSquares
    • jaxopt.ScipyBoundedMinimize
    • jaxopt.ScipyLeastSquares
    • jaxopt.ScipyMinimize
    • jaxopt.ScipyRootFinding
    • Implicit differentiation

    Examples

    • Binary kernel SVM with intercept.
    • Image classification example with Flax and JAXopt.
    • Image classification example with Haiku and JAXopt.
    • VAE example with Haiku and JAXopt.
    • Implicit differentiation of lasso.
    • Multiclass linear SVM (without intercept).
    • Non-negative matrix factorization (NMF) using alternating minimization.
    • Dataset distillation.
    • Implicit differentiation of ridge regression.
    • Robust training.
    • Anderson acceleration of gradient descent.
    • Anderson acceleration of block coordinate descent.
    • Anderson acceleration in application to Picard–Lindelöf theorem.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Robert Gower, Louis Bethune, Marco Cuturi, Mathieu Blondel, Peter Hawkins, Quentin Berthet, Roy Frostig, Ta-Chu Kao.
