Optimal Transport Tools (OTT), a toolbox for all things Wasserstein.

Overview

See full documentation for detailed info on the toolbox.

The goal of OTT is to provide sturdy, versatile and efficient optimal transport solvers, taking advantage of JAX features, such as JIT, auto-vectorization and implicit differentiation.

A typical OT problem has two ingredients: a pair of weight vectors a and b (one for each measure), and a ground cost matrix that is either given directly or derived as the pairwise evaluation of a cost function on pairs of points taken from the two measures. The main design choice in OTT is to encapsulate that cost in a Geometry object and bundle it with a few useful operations (notably kernel applications). The most common geometry is that of two clouds of vectors compared with the squared Euclidean distance, as illustrated in the example below:

Example

import jax
import jax.numpy as jnp
from ott.tools import transport
# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings via Sinkhorn algorithm.
ot = transport.Transport(x, y, a=a, b=b)
P = ot.matrix

The call above runs the Sinkhorn algorithm and stores its output. The transport matrix can then be instantiated using those optimal solutions and the Geometry again. That transport matrix links each point from the first point cloud to one or more points from the second, as illustrated below.

[figure: obtained coupling]

To be more precise, the Sinkhorn algorithm operates on the Geometry, taking into account the weights a and b, to solve the OT problem. It produces a named tuple that contains two optimal dual potentials f and g (vectors of the same size as a and b, respectively), the objective reg_ot_cost, a log of the errors of the algorithm as it converges, and a converged flag.
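
The snippet below is a minimal sketch of how these output fields can be inspected when calling the sinkhorn wrapper directly on a point-cloud geometry (reusing x, y, a, b from the example above). It assumes the module layout described in the next section; exact import paths and defaults may differ across OTT versions.

from ott.geometry import pointcloud
from ott.core import sinkhorn
# Reuses the point clouds and weights x, y, a, b sampled in the example above.
geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
out = sinkhorn.sinkhorn(geom, a=a, b=b)
out.f, out.g       # optimal dual potentials, sized like a and b respectively
out.reg_ot_cost    # regularized OT objective
out.errors         # errors logged while the algorithm converges
out.converged      # whether the stopping threshold was reached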

Overall description of source code

Currently implements the following classes and functions:

  • In the geometry folder,

    • The CostFn class in costs.py and its descendants define cost functions between points. Two simple costs are currently provided: Euclidean, between vectors, and Bures, between pairs of mean vectors and (p.d.) covariance matrices.

    • The Geometry class in geometry.py and its descendants describe a cost structure between two measures. That cost structure is accessed through various member functions, either used when running the Sinkhorn algorithm (typically kernel multiplications, or log-sum-exp row/column-wise application) or after (to apply the OT matrix to a vector).

      • In its generic Geometry implementation, as in geometry.py, an object can be initialized with either a cost_matrix along with an epsilon regularization parameter (or scheduler), or with a kernel_matrix.

      • If one wishes to compute OT between two weighted point clouds endowed with a given cost function (e.g. Euclidean), the PointCloud class in pointcloud.py can be used to define the corresponding kernel (see the sketch after this list). When the number of points grows very large, this geometry can be instantiated with an online=True parameter, to avoid storing the kernel matrix and instead recompute it on the fly at each application.

      • Similarly, if all measures to be considered are supported on a separable grid, and the cost is separable along all axes, i.e. the cost between two points on that grid is equal to the sum of (possibly different) cost functions evaluated on each pair of coordinates, then the application of the kernel is much simplified, both in log space and on the histograms themselves. This particular case is exploited in the Grid geometry in grid.py, which can be instantiated as a hypercube using a grid_size parameter, or directly through grid locations in x.

      • LRCGeometry, low-rank cost geometries, of which a PointCloud endowed with the squared Euclidean distance is a particular example, can efficiently apply their cost to another matrix. This is leveraged in particular in the low-rank Sinkhorn (and Gromov-Wasserstein) solvers.

  • In the core folder,

    • The sinkhorn function in sinkhorn.py is a wrapper around the Sinkhorn solver class. It runs the Sinkhorn algorithm, with the aim of approximately solving one or several optimal transport problems in parallel. An OT problem is defined by a Geometry object and a pair (or batch thereof) of histograms. The function's outputs are stored in a SinkhornOutput named tuple, containing the potentials, the regularized OT cost, the sequence of errors and a convergence flag. Such outputs (with the exception of the errors and the convergence flag) can be differentiated w.r.t. any of the three inputs (Geometry, a, b), either through backprop or through implicit differentiation of the optimality conditions of the optimal potentials f and g.

    • A later addition in sinkhorn_lr.py is focused on the LRSinkhorn solver class, which is able to solve OT problems at larger scales using an explicit low-rank factorization of couplings.

    • In discrete_barycenter.py: implementation of discrete Wasserstein barycenters: given histograms all supported on the same Geometry, compute a barycenter of these measures, using an algorithm by Janati et al. (2020).

    • In gromov_wasserstein.py: implementation of two Gromov-Wasserstein solvers (both entropy-regularized and low-rank) to compare two measured-metric spaces, here encoded as a pair of Geometry objects, geom_xx and geom_yy, along with weights a and b. Additional options include using a fused term by specifying geom_xy.

  • In the tools folder,
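
As referenced above, the following minimal sketch instantiates the main geometries described in this section and differentiates the regularized OT cost through the sinkhorn wrapper. It assumes the geometry/core module layout above (ott.geometry, ott.core); exact argument names and defaults may differ across OTT versions.

import jax
import jax.numpy as jnp
from ott.geometry import geometry, pointcloud, grid
from ott.core import sinkhorn

rngs = jax.random.split(jax.random.PRNGKey(1), 2)
x = jax.random.normal(rngs[0], (200, 3))
y = jax.random.normal(rngs[1], (300, 3))

# Generic geometry: an explicit cost matrix plus an epsilon regularizer.
cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
geom_generic = geometry.Geometry(cost_matrix=cost, epsilon=1e-1)
# Point-cloud geometry; online=True recomputes the kernel on the fly.
geom_cloud = pointcloud.PointCloud(x, y, epsilon=1e-1, online=True)
# Grid geometry on a 10 x 15 x 20 hypercube, exploiting separability of the cost.
geom_grid = grid.Grid(grid_size=(10, 15, 20), epsilon=1e-1)

# Any such geometry can be passed to the sinkhorn wrapper (uniform weights by
# default); its reg_ot_cost output can be differentiated, e.g. w.r.t. x.
def loss(x):
  return sinkhorn.sinkhorn(pointcloud.PointCloud(x, y, epsilon=1e-1)).reg_ot_cost

grad_x = jax.grad(loss)(x)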

Comments
  • add first version of SCOT demonstration notebook (Gromov-Wasserstein for multi-omics)

    Adds a second biology example to the documentation notebooks.

    The notebook presents an application of OTT's Gromov-Wasserstein optimal transport to match single-cell point clouds from two different measurement spaces (e.g. mapping gene expression measurements to chromatin accessibility measurements). It is adapted from Demetci et al., Gromov-Wasserstein optimal transport to align single-cell multi-omics data, ICML 2020 Workshop on Computational Biology, 2020.

    opened by antoinebelloir 32
  • Pre-commit

    This PR introduces some pre-commits, some configuration files, .gitignore and a missing RST template for class (related to docs formatting in #66). Per https://google.github.io/styleguide/pyguide.html#34-indentation, and because black doesn't support 2 spaces as indentation, it's now 4 spaces. I've set the max line length to 120 (given current resolution; I always use this), but this can be adapted in pyproject.toml. The flake8 pre-commits are failing; I just wanted to start a discussion on whether the style is ok. Afterwards, I will start fixing these.

    To try this out, run in the root of the project:

    pre-commit install  # needs to be done only once; will run `pre-commit` before committing/pushing
    pre-commit run --all-files
    

    closes #43

    opened by michalk8 10
  • Feature/ Introduce initialization methods for Sinkhorn

    • Examples here: https://colab.research.google.com/drive/1vncmDEr3t6_OKfVC0PJin8PIRBViPyO0?usp=sharing

    • Initialization methods stored in /core/initializers.py

      • Each initializer inherits from base class SinkhornInitializer
        • The class structure is flexible enough to accommodate neural-network initializers in the future
        • Currently added:
          • sorting initializer (n=m) for any cost
          • Gaussian initializer for the squared ground cost; only works for point clouds since it needs access to x
      • Each initializer has methods init_dual_a and init_dual_b
          def init_dual_a(
                self, ot_problem: LinearProblem, lse_mode: bool = True
            ) -> jnp.ndarray:
              """Initialzation for Sinkhorn potential f.
      
      • The base class also holds the default behaviour: initialization to 0 or 1 depending on lse_mode, as well as handling of entries with 0 weights
    • Modifications to the Sinkhorn API:

      • pass instantiated initializer to Sinkhorn
      Sinkhorn(
         lse_mode=lse_mode,
         threshold=threshold,
         norm_error=norm_error,
         inner_iterations=inner_iterations,
         min_iterations=min_iterations,
         max_iterations=max_iterations,
         momentum=momentum_lib.Momentum(start=chg_momentum_from, value=momentum),
         anderson=anderson,
         implicit_diff=implicit_diff,
         parallel_dual_updates=parallel_dual_updates,
         use_danskin=use_danskin,
         potential_initializer=potential_initializer,
         jit=jit
      )
      
      • Can be passed to sinkhorn functional wrapper
      gaus_init = init_lib.GaussianInitializer()
      
      @jax.jit
      def run_sinkhorn_gaus_init(x, y, a=None, b=None):
          sink_kwargs = {'jit': True, 
                      'threshold': 0.001, 
                      'max_iterations': 10**5, 
                      'potential_initializer': gaus_init}
      
          geom_kwargs = {'epsilon': 0.01}
          geom = PointCloud(x, y, **geom_kwargs)
          out = sinkhorn(geom, a=a, b=b, **sink_kwargs)
          return out
      
      • Initialization strategy will be overridden if init_dual_a/ init_dual_b vectors are passed to solver
      • Addressed this issue https://github.com/ott-jax/ott/issues/84#issue-1285013301 in tools wrapper
    enhancement 
    opened by JTT94 8
  • Functional way to run Sinkhorn algorithm via `sinkhorn.sinkhorn` seems to be significantly slower than `sinkhorn.Sinkhorn()(...)`

    According to the different notebooks, there are two ways to run Sinkhorn, which seem to be equivalent.

    For example, suppose we have two point clouds x and y for which we want to calculate the corresponding reg_ot_cost. The two ways to compute this reg_ot_cost are:

    THRESHOLD = 1e-3
    MAX_ITERATIONS = 2_000
    EPSILON = 1e-3
    cost_fn = SqEuclidean()
    
    def get_reg_cost_1(x, y):
        "Functional."
        cost_matrix = cost_fn.all_pairs(x, y)
        geom = geometry.Geometry(cost_matrix, epsilon=EPSILON)
        return sinkhorn.sinkhorn(
            geom,
            **{"threshold": THRESHOLD,
               "max_iterations": MAX_ITERATIONS} # sinkhorn kwargs 
            ).reg_ot_cost
        
    def get_reg_cost_2(x, y):
        "Via Sinkhorn class."
        cost_matrix = cost_fn.all_pairs(x, y)
        geom = geometry.Geometry(cost_matrix, epsilon=EPSILON)
        return sinkhorn.Sinkhorn(
                threshold=THRESHOLD,
                max_iterations=MAX_ITERATIONS,
            )(linear_problem.LinearProblem(geom)).reg_ot_cost
    

    What I understand (and which may be wrong): the only thing that differs between these two ways of running the Sinkhorn algorithm is that in the first one, we don't directly instantiate an object of the Sinkhorn class; this is done inside the sinkhorn.sinkhorn function, by the make function. In the second one, we instantiate an object of the Sinkhorn class and then call this object to run the Sinkhorn algorithm. The first way seems more "functional" than the second.

    Both of these ways are used in notebooks, for example the first is used in the notebook "Meta OT and Sinkhorn Initializers" https://ott-jax.readthedocs.io/en/latest/notebooks/MetaOT.html and the second is used in the notebook "Point clouds" https://ott-jax.readthedocs.io/en/latest/notebooks/point_clouds.html (which seems to have been updated recently).

    However, they give very different running times: the first seems to be significantly slower than the second.

    If I run the following code on my machine:

    # generate data
    rng = random.PRNGKey(0)
    rngs = random.split(rng, 2)
    NUM_SAMPLES, DIM = 1024, 10
    x = random.uniform(rng, minval=-1, maxval=1, shape=(NUM_SAMPLES, DIM))
    y = random.normal(rng, shape=(NUM_SAMPLES, DIM))
    
    %timeit -n100 -r3 get_reg_cost_1(x, y).block_until_ready()
    %timeit -n100 -r3 get_reg_cost_2(x, y)
    

    I get that get_reg_cost_1 is more than 70 times slower than get_reg_cost_2:

    720 ms ± 4.1 ms per loop (mean ± std. dev. of 3 runs, 100 loops each)
    10 ms ± 319 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
    

    So I guess I'm making a mistake somewhere in the use of solvers? Or if not, how can we explain this difference?

    Thank you very much!

    opened by theouscidda6 7
  • Meta OT Initializer

    This is an initial version of a Meta OT initializer along with a notebook demonstrating its usage. Here is how the notebook renders on the website, and here is how the code docs render.

    enhancement 
    opened by bamos 7
  • Fix: switch in the source and target distribution allocation of the potentials

    Quick edit: I believe my understanding was flawed, sorry. It could be correct as it currently is. I (perhaps wrongly) assumed Y would denote the target distribution and that we were looking for the map from X to Y (hence, source to target). However, it seems that Y is sampled from Q, which is stated as the source distribution in the paper (not only in the pseudocode, but also at other locations). The goal is to learn the map from Q to P. I realized this after taking a more detailed look at the terminology used there. In hindsight, my assessment seems to have been overhasty. I would still be interested if someone could clarify my understanding, however, so that I actually get it right. :-)


    Perhaps my understanding is also wrong (which I cannot rule out).

    According to the original paper, and my understanding, the g potential should be applied to the target distribution and the f potential to the source (see, e.g., Eq. 5 or Eq. 9). In the current version, this is switched.

    I suspect this mix-up comes from the fact that Algorithm 1 in the paper has the source and target also mixed up in the pseudocode. Here, the source distribution is indeed given to the g potential and the target distribution to f. However, following the equations, I think this is simply a mistake (although I am not 100% sure).

    My suspicion is further backed by the original implementation of the authors, where the source samples are given to the f potential and the target to the g.

    After changing this accordingly, I feel that this improved the convergence speed as well (from a few first test runs).

    opened by ppnaumann 6
  • custom_vjp seems to not be used during computation of higher order derivatives of Sinkhorn

    Hello,

    It seems that the custom_vjp (defined with either _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd) in sinkhorn.py for the implicit differentiation, or fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd) in fixed_point_loop for the unrolling) is not used for the computation of higher-order derivatives. I believe that this is also related to the error discussed previously in this issue. This is not evident with the computation of jax.hessian (which is implemented as jax.jacfwd(jax.jacrev)), as the second differentiation uses forward-mode automatic differentiation, which is compatible with jax.lax.while_loop. Therefore, no error is raised even if the custom_vjp is ignored the second time.

    I think that this can be seen by adding a breakpoint in the _while_loop_jvp of JAX's control_flow.py. For instance, the computation of jax.hessian passes through _while_loop_jvp first, then one time through fixpoint_iter for implicit diff (or fixpoint_iter_fwd for unrolling) and then one time through _iterations_implicit_bwd for implicit diff (or fixpoint_iter_bwd for unrolling). In the case where a jax.lax.scan is forced by setting min_iterations equal to max_iterations (and both Jacobians are computed with reverse mode), instead of the initial pass through _while_loop_jvp, one gets in the end a pass through _scan_transpose of JAX's control_flow.py.

    In either case, I think that if the custom_vjp was not ignored during rederivation, one should get two passes through _iterations_implicit_bwd (equivalently fixpoint_iter_bwd), right?

    I am not sure how this could be fixed. The only relevant info that I could find in JAX's documentation for custom_vjp was this:

    "Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)"

    I hope I am not missing something. If there is a solution to this, it would be great, as what I need to compute is jax.jacrev(jax.grad) for the sinkhorn divergence and the forced scan option causes a significant computational overhead, especially for the two autocorrelation terms Pxx, Pyy.

    Many thanks!

    opened by ersisimou 6
  • handling padding when computing scaling for cost matrices.

    The segmented approaches used to pad point clouds efficiently rely on padding with "default" (e.g. zero) vectors. These padded vectors create spurious entries in cost matrices, but when running Sinkhorn these entries are ignored, thanks to the fact that they have 0 weights (in the respective a and b weight vectors). However, these entries would play a role when computing scalings (such as the mean cost).

    A possible approach would be to redefine some of those statistics (e.g. max, mean, etc.) to only apply to entries that have a positive weight (e.g. C_ij is considered iff a_i b_j > 0). For the mean cost, the mean might be weighted by a_i b_j as well (see the sketch below).
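
    As a rough illustration of that suggestion (not existing OTT API; cost, a and b are assumed to be given, with zero weights on padded entries), such masked statistics could look like:

    import jax.numpy as jnp

    def masked_mean_cost(cost, a, b):
      w = jnp.outer(a, b)                    # zero on padded rows/columns
      mask = w > 0                           # keep entries with a_i * b_j > 0
      return jnp.sum(jnp.where(mask, cost, 0.0)) / jnp.sum(mask)

    def weighted_mean_cost(cost, a, b):
      w = jnp.outer(a, b)
      return jnp.sum(w * cost) / jnp.sum(w)  # mean cost weighted by a_i * b_j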

    enhancement 
    opened by marcocuturi 5
  • application of `low-rank` to single-cell data

    TL;DR: in this colab we provide an example of our failure to obtain a valid mapping using low-rank.

    problem setup: In this example data set we look into mapping spatial transcriptomics at single-cell resolution from mouse embryonic tissues across two time-points, so the quadratic term accounts for distances in spatial coordinates and the linear term captures distances in gene expression.

    evaluation of the mapping: As an initial sanity check we look at the cell-transition table, that is, the transition matrix with entries grouped by cell types (we assess the row-stochastic, forward, setting). The naive assumption is that cells of the same type, e.g. brain, will be mapped mainly to themselves. Evaluating the regular FGW and FGW (unbalanced), this is indeed what we observe. However, for low-rank we get a matrix with constant columns. We observed a similar phenomenon at different time-points. Comparing the results, we can see hints of the constant columns, as they correspond to cell types also favored in the regular regime.

    bug 
    opened by zoepiran 5
  • catching `online=True` in `gromov_wasserstein`

    hi, we noticed that one can accidentally pass online=True when using gromov_wasserstein, and this will propagate through **kwargs. Eventually the mapping will fail at some point due to a dimensionality assertion, but we thought it may be useful to catch it earlier and warn. What do you think?

    opened by zoepiran 5
  • Enable to set parameters for inner Sinkhorn loop in Gromov-Wasserstein

    Currently, it is not possible to set parameters like max_iterations and min_iterations of the inner Sinkhorn loop when calling WassersteinSolver or GromovWasserstein, because the inner and outer loops share the same argument names.

    It would be great to rename the outer loop parameters to have more control over the algorithm.

    opened by MUCDK 4
  • [WIP] Improvements to the neural dual solver

    Hi @marcocuturi @bunnech @michalk8, here is a WIP PR with the updates we've discussed to the neural dual solver. It improves the example from:

    [before screenshot]

    to this:

    [after screenshot]

    My paper has some more context on these updates adding some more options for fine-tuning and amortizing the conjugate approximation. I've marked this as a WIP because I'm still debugging the inverse map in the example, but otherwise, the rest of the code is ready for an initial review as it would be good to get your thoughts on the core updates and API/doc changes. I can also call and chat a little more about these soon if that would be helpful. Here are quick summaries of my updates to the code and notebook.

    Updates to the core neural solver

    1. swap the order of f and g potentials (\cc https://github.com/ott-jax/ott/issues/182#issuecomment-1344895638)
    2. add an option to add a conjugate solver that fine-tunes g's prediction when updating f, along with a default solver using JaxOpt's LBFGS optimizer
    3. add amortization modes to learn g: regression and objective (Makkuva et al.)
    4. an option to update both potentials with the same batch (in parallel) in addition to the separate inner loop Makkuva et al. uses
    5. the option to model the gradient of the g potential with a neural network and a simple MLP implementing this instead of the ICNN from Makkuva et al.
    6. a callback so the training progress can be monitored (debugging the notebook/code was very difficult without this)

    Updates to the notebook/example

    (I'm still tweaking these around a little as it's still a little suboptimal)

    1. Update the f ICNN size from [64, 64, 64, 64] to [128, 128] and modify the architecture
    2. Use an MLP to model the gradient of g
    3. Use the default JaxOpt LBFGS conjugate solver and update g to regress onto it
    4. Increase batch size from 1000 to 10000
    5. Sample the data in batches rather than sequentially
    6. Tweak Adam (use weight decay and a cosine schedule)

    I've also tried running the code with an approach closer to the older implementation, i.e., the approach from Makkuva et al. that models g with an ICNN and uses 10 inner objective-based updates, but can't get it to work as nicely as the version with conjugate fine-tuning+regression.

    Remaining discussion and TODOs

    • [ ] Decide on best format for the arguments. E.g., should the conjugate solver be a function and the amortization mode a string?
    • [ ] Decide on the best defaults to use (a lightweight LBFGS conjugate solver for fine-tuning that regresses g onto it, or something closer to the approach in Makkuva et al.?)
    • [ ] Improve the documentation and notebook to clarify these options
    • [ ] Fix the inverse map in the notebook
    • [ ] Think about any tests that would be useful to add for these updates
    • [ ] Update other places using the Neural OT code (like the initialization schemes notebook)
    • [ ] Decide on the default ICNN architecture
    • [ ] Add ability to save/load the parameters and ability to re-initialize the solves with them
    enhancement bugfix 
    opened by bamos 2
  • Non-jitting OOM issues

    Without jitting, some properties such as SinkhornOutput.primal_cost (using an online point-cloud geometry), as well as functions such as DualPotentials.Transport or k_means, cause OOM issues. Previously, jitting was done by default, see #192. The question is whether to re-introduce the jit argument to solvers (which would again be True by default), always jit the classes, or try another solution.

    opened by michalk8 0
  • Make threshold in Sinkhorn algorithms (and more generally linear solvers) depend on dimension and `norm_error` used, so that it becomes relative.

    in one way or another...

    Use case: when sinkhorn is run, the marginal deviation is used as a stopping criterion.

    By default the 1-norm is used: https://github.com/ott-jax/ott/blob/3ebac6369acfe848930e1e2abb69b61ecea7e5cc/src/ott/solvers/linear/sinkhorn.py#L455

    and fed into error computation. https://github.com/ott-jax/ott/blob/3ebac6369acfe848930e1e2abb69b61ecea7e5cc/src/ott/solvers/linear/sinkhorn.py#L209

    Typically, this involves comparing two marginal (probability) vectors of size n or m. Because those vectors are constrained to lie in the simplex of that dimension, the tolerance (what constitutes a small deviation) should scale gracefully (if needed, with n or m) and depend on the norm that is used.

    Also, an alternative would be to add the Hellinger distance as a norm to control Sinkhorn...

    opened by marcocuturi 0
  • Progress report

    Hi, for computations that can take several hours, it would be great if OTT-jax reported progress. I can see at least two types of progress reports:

    1. In some computations, the number of steps is known in advance, for instance when using the optional batch_size parameter of PointCloud with Sinkhorn or LRSinkhorn: each Sinkhorn iteration is a fixed set of submatrix operations.

    2. In some other cases, the number of steps is indeterminate because we're waiting for some convergence criterion. But OTT-jax could in these cases report the number of iterations performed so far and the value associated with the convergence criterion, the same way you would report a loss in a training loop.

    Because reporting progress is a side effect, jax has introduced the host_callback module. I think it's worth a try.

    I don't think OTT-jax should depend on progress report libs; it would instead be the responsibility of the user to provide their own callback function, such as tqdm.
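
    For what it's worth, here is a rough sketch of the kind of reporting I have in mind (this is only a sketch, not existing OTT code, and it assumes a JAX version that still ships jax.experimental.host_callback):

    import jax
    import jax.numpy as jnp
    from jax.experimental import host_callback

    def report(args, _):
      step, err = args
      print(f"iteration {int(step)}: error {float(err):.2e}")  # runs on the host

    @jax.jit
    def toy_loop(num_steps):
      def body(i, err):
        err = err * 0.9  # stand-in for a convergence criterion
        # id_tap returns its argument; threading err through keeps the callback alive
        _, err = host_callback.id_tap(report, (i, err))
        return err
      return jax.lax.fori_loop(0, num_steps, body, jnp.array(1.0))

    toy_loop(5)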

    Would you be open to this?

    enhancement 
    opened by bosr 1
  • Offer a way to extract a few rows of potentially large cost matrices

    For some transport problems, when searching for nearest neighbors, it is useful to analyze parts of the cost matrix. As of version 0.3.1, the API only offers a way to materialize and return the entire cost matrix, which is sometimes prohibitively expensive, or impossible (OOM).

    I would like the API to offer a way to return only rows and columns of the cost matrix, or the documentation to show how to do that.

    The good news is that, for problems solved with Sinkhorn, extracting a few rows of the cost matrix is simple thanks to jax:

    geom = pointcloud.PointCloud(...)  # n points in the input domain
    lin_prob = linear_problem.LinearProblem(...)
    out = sinkhorn.Sinkhorn()(lin_prob)
    
    # method 1 - full materialization then extract row k
    # C = geom.cost_matrix
    # C_k = C[k, :]
    
    # alternative - extract row k
    u = jnp.zeros((1, n)).at[0, k].set(1.0)
    C_k = out.apply(u)
    

    Would it make sense for you?

    opened by bosr 6
  • Running `segment_sinkhorn` with different `sinkhorn_kwargs` for each segment

    Is it possible to run segment_sinkhorn with different sinkhorn_kwargs for each segment?

    For example, with two PointClouds x and y, each with two segments, (x[0:1024], y[0:1024]) and (x[1024:2048], y[1024:2048]), can segment_sinkhorn be used to run Sinkhorn in parallel on (x[0:1024], y[0:1024]) and (x[1024:2048], y[1024:2048]) but with different sinkhorn_kwargs?

    If this is not possible, would you consider adding this feature?

    In particular, it would be useful to compute the sinkhorn_divergence with segment_sinkhorn while keeping the benefits of the particular sinkhorn_kwargs that speed up the computation of the symmetric terms.

    More precisely, the computation of the sinkhorn_divergence between two PointClouds x and y involves the computation of 3 transports, namely between the 3 segments: (x,y), (x,x) and (y,y). With segment_sinkhorn one could compute these 3 transports in parallel. On the other hand, the terms induced by the symmetric segments (x,x) and (y,y) can be computed faster using sinkhorn_kwargs={parallel_dual_updates=True, momentum=0.5, chg_momentum_from=0, anderson_acceleration=0}. So to be optimal, we should be able to run segment_sinkhorn with these particular sinkhorn_kwargs for the segments (x,x) and (y,y) but not for the segment (x,y), hence the usefulness of being able to use different sinkhorn_kwargs.

    Thank you very much!

    enhancement 
    opened by theouscidda6 2
Releases (0.3.1)
  • 0.3.1(Dec 15, 2022)

    What's Changed

    • Fix jax.device_put for potentials by @michalk8 in https://github.com/ott-jax/ott/pull/183
    • Introduce primal_cost and dual_cost for Sinkhorn outputs (only primal for LR) for arbitrary geometries. by @marcocuturi in https://github.com/ott-jax/ott/pull/184
    • Feature/src layout by @michalk8 in https://github.com/ott-jax/ott/pull/188
    • Handle better inf with p parameter in SqPNorm by @marcocuturi in https://github.com/ott-jax/ott/pull/189
    • Feature/graph normalized laplacian by @michalk8 in https://github.com/ott-jax/ott/pull/191
    • Update installation instruction by @michalk8 in https://github.com/ott-jax/ott/pull/195
    • Feature/center potentials by @michalk8 in https://github.com/ott-jax/ott/pull/194
    • Remove jit from solvers by @michalk8 in https://github.com/ott-jax/ott/pull/192
    • Accelerating unbalanced OT by @michalk8 in https://github.com/ott-jax/ott/pull/197
    • Feature/depracate functional api by @michalk8 in https://github.com/ott-jax/ott/pull/204
    • Fixes numerical errors in Bures barycenter, and sqrtm, due to low default precision. by @marcocuturi in https://github.com/ott-jax/ott/pull/205
    • Update docstrings for deprecated functions by @michalk8 in https://github.com/ott-jax/ott/pull/207

    Full Changelog: https://github.com/ott-jax/ott/compare/0.3.0...0.3.1

  • 0.3.0(Nov 23, 2022)

    The main changes in this version are twofold:

    • Changes in the PointCloud geometry, and more specifically in the handling of the cost function. The power parameter that was used to optionally pass c(x,y) = CostFn(x,y) ** power is now deprecated; one can add it manually by defining a custom CostFn. To compensate for this change, a new class of translation-invariant costs (TICost) has been created, from which most costs now inherit, defined as c(x,y) = h(x-y). Additionally, to handle Brenier's theorem, the user has the option of passing the Legendre transform of h, h_legendre (see the sketch below).
    • The core folder was too horizontal, containing various modules. It has been reorganized and split into 3 modules that make more sense: problems (to describe OT problems), solvers (to solve them) and initializers (optional modules to help solvers). The latter two have a folder structure that mirrors that of problems.
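
    For illustration, a minimal sketch of a custom translation-invariant cost under this new interface (assuming TICost exposes the h and h_legendre hooks described above; exact signatures may differ):

    import jax.numpy as jnp
    from ott.geometry import costs

    class HalfSqEuclidean(costs.TICost):
      """c(x, y) = h(x - y) with h(z) = 0.5 * ||z||^2; h is its own Legendre transform."""

      def h(self, z: jnp.ndarray) -> float:
        return 0.5 * jnp.sum(z ** 2)

      def h_legendre(self, z: jnp.ndarray) -> float:
        return 0.5 * jnp.sum(z ** 2)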

    What's Changed

    • fix point_clouds.ipynb by @marcocuturi in https://github.com/ott-jax/ott/pull/159
    • Deprecate power in PointCloud, introduce TICost and use it to compute Entropic (Brenier) maps. by @marcocuturi in https://github.com/ott-jax/ott/pull/167
    • Misc/project structure by @michalk8 in https://github.com/ott-jax/ott/pull/176

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.11...0.3.0

  • 0.2.11(Oct 19, 2022)

    What's Changed

    • default momentum with start=300 by @marcocuturi in https://github.com/ott-jax/ott/pull/155
    • Unbalanced gromov by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/128
    • Euclidean / SqEuclidean and changes in power by @marcocuturi in https://github.com/ott-jax/ott/pull/157

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.10...0.2.11

  • 0.2.10(Oct 14, 2022)

    What's Changed

    • Make GMM documentation more visible, fix bug in M step of EM algorithm by @geoff-davis in https://github.com/ott-jax/ott/pull/144
    • Feature/initializers as literals by @michalk8 in https://github.com/ott-jax/ott/pull/148
    • Fix fused_penalty and scale_cost in LRGW by @michalk8 in https://github.com/ott-jax/ott/pull/147
    • Add dtype property by @michalk8 in https://github.com/ott-jax/ott/pull/150
    • Meta OT Initializer by @bamos in https://github.com/ott-jax/ott/pull/145
    • fix power=1.0 using abs on cost values when needed by @marcocuturi in https://github.com/ott-jax/ott/pull/153

    New Contributors

    • @geoff-davis made their first contribution in https://github.com/ott-jax/ott/pull/144

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.9...0.2.10

  • 0.2.9(Oct 3, 2022)

    What's Changed

    • Feature/kmeans++ by @michalk8 in https://github.com/ott-jax/ott/pull/120
    • Feature/ Introduce initialization methods for Sinkhorn by @JTT94 in https://github.com/ott-jax/ott/pull/98
    • Pin sphinx-book-theme>=0.3.3 by @michalk8 in https://github.com/ott-jax/ott/pull/123
    • Additional Docstrings and Comments to NeuralDual-Module by @bunnech in https://github.com/ott-jax/ott/pull/122
    • Allow k-means to be differentiable when using while loop by @michalk8 in https://github.com/ott-jax/ott/pull/130
    • Feature/graph geometry by @michalk8 in https://github.com/ott-jax/ott/pull/126
    • LR Sinkhorn improvements by @meyerscetbon in https://github.com/ott-jax/ott/pull/111
    • Refactor GW initialization by @michalk8 in https://github.com/ott-jax/ott/pull/133
    • Change LR initializer to random by @michalk8 in https://github.com/ott-jax/ott/pull/134
    • Generalized k-means initialization for LR Sinkhorn/Gromov-Wasserstein by @michalk8 in https://github.com/ott-jax/ott/pull/135
    • Switch to one linear system per implicit differentiation by @ersisimou in https://github.com/ott-jax/ott/pull/136
    • Feature/disable warm start by @michalk8 in https://github.com/ott-jax/ott/pull/138
    • Add entropic map by @michalk8 in https://github.com/ott-jax/ott/pull/142

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.8...0.2.9

  • 0.2.8(Aug 3, 2022)

    What's Changed

    • Refactoring problems.py by @marcocuturi in https://github.com/ott-jax/ott/pull/86
    • Feature/gw losses by @michalk8 in https://github.com/ott-jax/ott/pull/88
    • Fix/lr linear apply by @michalk8 in https://github.com/ott-jax/ott/pull/91
    • Addition of ICNN Initialization Schemes by @bunnech in https://github.com/ott-jax/ott/pull/90
    • Fix/point cloud apply cost fn by @michalk8 in https://github.com/ott-jax/ott/pull/93
    • Feature/lr gw apply output by @michalk8 in https://github.com/ott-jax/ott/pull/96
    • Remove legacy online option by @michalk8 in https://github.com/ott-jax/ott/pull/97
    • Feature/fast tests by @michalk8 in https://github.com/ott-jax/ott/pull/101
    • Fix/axis norm by @michalk8 in https://github.com/ott-jax/ott/pull/103
    • Generic LR cost decomposition by @michalk8 in https://github.com/ott-jax/ott/pull/99
    • Remove pytest-memray from test req, use only in CI by @michalk8 in https://github.com/ott-jax/ott/pull/107
    • demo and fixes for barycenters of GMMs by @ersisimou in https://github.com/ott-jax/ott/pull/89
    • Add (F)GW barycenters by @michalk8 in https://github.com/ott-jax/ott/pull/87
    • Feature/notebook tests by @michalk8 in https://github.com/ott-jax/ott/pull/108
    • [ci skip] Fix duplicate bibtex entry and a typo by @michalk8 in https://github.com/ott-jax/ott/pull/109
    • GW different cost ranks by @michalk8 in https://github.com/ott-jax/ott/pull/113
    • segment_sinkhorn by @marcocuturi in https://github.com/ott-jax/ott/pull/114
    • Create CODE_OF_CONDUCT.md by @marcocuturi in https://github.com/ott-jax/ott/pull/118
    • Fix/padded scaling by @michalk8 in https://github.com/ott-jax/ott/pull/116

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.7...0.2.8

  • 0.2.7(Jun 22, 2022)

    What's Changed

    • Pre-commit by @michalk8 in https://github.com/ott-jax/ott/pull/71
    • Fix/online gw by @michalk8 in https://github.com/ott-jax/ott/pull/80
    • Feature/improve packaging by @michalk8 in https://github.com/ott-jax/ott/pull/82
    • Ersi gauss bary by @ersisimou in https://github.com/ott-jax/ott/pull/81

    New Contributors

    • @ersisimou made their first contribution in https://github.com/ott-jax/ott/pull/81

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.6...0.2.7

  • 0.2.6(May 24, 2022)

    What's Changed

    • Fix jit issue with relative epsilon and copying epsilons by @JTT94 in https://github.com/ott-jax/ott/pull/52
    • Bug/scale cost by @MUCDK in https://github.com/ott-jax/ott/pull/54
    • ensure 3 geom objects returned when copying epsilons by @JTT94 in https://github.com/ott-jax/ott/pull/56
    • fixed logic of __call__ of LRSinkhorn to prevent uninterpretable error by @MUCDK in https://github.com/ott-jax/ott/pull/57
    • Project import generated by Copybara. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/58
    • General cost fn by @JTT94 in https://github.com/ott-jax/ott/pull/61
    • Set converged logic by @zoepiran in https://github.com/ott-jax/ott/pull/62
    • Fix LRSinkhornOutput.apply with batches by @michalk8 in https://github.com/ott-jax/ott/pull/63
    • Fix is_all_geoms_lr not returning by @michalk8 in https://github.com/ott-jax/ott/pull/67
    • Feature/quadratic problem scale by @michalk8 in https://github.com/ott-jax/ott/pull/66
    • Fix LR tests, fused_penalty property, tolerances by @michalk8 in https://github.com/ott-jax/ott/pull/68
    • add first version of SCOT demonstration notebook (Gromov-Wasserstein for multi-omics) by @antoinebelloir in https://github.com/ott-jax/ott/pull/64
    • Fix LR Gromov memory by @michalk8 in https://github.com/ott-jax/ott/pull/70
    • Fix scale_cost to LRGeometry, float/max-bound by @michalk8 in https://github.com/ott-jax/ott/pull/72
    • Update README to use Markdown math by @adrhill in https://github.com/ott-jax/ott/pull/75
    • Fix neural dual notebook by @lucaeyring in https://github.com/ott-jax/ott/pull/76

    New Contributors

    • @JTT94 made their first contribution in https://github.com/ott-jax/ott/pull/52
    • @MUCDK made their first contribution in https://github.com/ott-jax/ott/pull/54
    • @antoinebelloir made their first contribution in https://github.com/ott-jax/ott/pull/64
    • @adrhill made their first contribution in https://github.com/ott-jax/ott/pull/75
    • @lucaeyring made their first contribution in https://github.com/ott-jax/ott/pull/76

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.5...0.2.6

  • 0.2.5(Apr 2, 2022)

    What's Changed

    • not converged logic by @zoepiran in https://github.com/ott-jax/ott/pull/50

    New Contributors

    • @zoepiran made their first contribution in https://github.com/ott-jax/ott/pull/50

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.4...0.2.5

  • 0.2.4(Mar 30, 2022)

    What's Changed

    • Fixes rank and epsilon arguments for LR-GW (see issue #21) by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/28
    • Add Neural Dual Solver and Utilities by @bunnech in https://github.com/ott-jax/ott/pull/32
    • Migrate away from using JaxTestCase in tests by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/35
    • Adding scaling factor to the cost matrix. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/38
    • Initialization LRSinkhorn by @meyerscetbon in https://github.com/ott-jax/ott/pull/47
    • Adding max_cost scaling for LR. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/48
    • Free support Wasserstein barycenter by @marcocuturi in https://github.com/ott-jax/ott/pull/49

    New Contributors

    • @bunnech made their first contribution in https://github.com/ott-jax/ott/pull/32
    • @meyerscetbon made their first contribution in https://github.com/ott-jax/ott/pull/39

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.3...0.2.4

  • 0.2.3(Mar 2, 2022)

    What's Changed

    • Update direct transport plan computation by @theouscidda6 in https://github.com/ott-jax/ott/pull/5
    • Updating GW notebook to work with v0.2.2, adding unbalanced and fused GW options to Transport. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/4
    • Initialising linear_convergence and costs in GWState to -1 by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/6
    • README example: transport.Transport -> transport.solve by @bamos in https://github.com/ott-jax/ott/pull/7
    • Fixes bug in transport.solve by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/8
    • Remove custom vdot implementation by @michalk8 in https://github.com/ott-jax/ott/pull/10
    • Fixes imports in tests. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/11
    • Fixes tests. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/12
    • Fixes tests. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/13
    • Fixing tests. by @LaetitiaPapaxanthos in https://github.com/ott-jax/ott/pull/14
    • Entropic regularization for LR-GW by @marcocuturi in https://github.com/ott-jax/ott/pull/19
    • Batch apply_lse_kernel for online=True by @michalk8 in https://github.com/ott-jax/ott/pull/23

    New Contributors

    • @theouscidda6 made their first contribution in https://github.com/ott-jax/ott/pull/5
    • @LaetitiaPapaxanthos made their first contribution in https://github.com/ott-jax/ott/pull/4
    • @bamos made their first contribution in https://github.com/ott-jax/ott/pull/7
    • @michalk8 made their first contribution in https://github.com/ott-jax/ott/pull/10

    Full Changelog: https://github.com/ott-jax/ott/compare/0.2.2...0.2.3

  • 0.2.2(Feb 3, 2022)

  • 0.2.1(Feb 3, 2022)

  • 0.2.0(Jan 24, 2022)

    Version 0.2.0 is out with major changes:

    • The toolbox will now be developed in the ott-jax org's repo, and has moved out of google_research. The pip path remains the same, i.e. pip install ott-jax.
    • Introduction of LRCGeometry, i.e. low-rank geometries that can apply their cost in an efficient manner.
    • The sinkhorn function has been rewritten as a wrapper around a Sinkhorn solver class, using SinkhornState and SinkhornOutput variables. In this refactoring, momentum and anderson are now objects.
    • Same for gromov_wasserstein, which now runs the GromovWasserstein solver.
    • OT problems are now defined in problems and quad_problems.
    • Addition of low-rank solvers, both to solve Sinkhorn (LRSinkhorn) and Gromov-Wasserstein (one just needs to specify a rank).
    • Unbalanced GW solver.
    • The tools folder now holds tools to define and manipulate Gaussians and mixtures of Gaussians.
    • and many more...