From 590e4d714f73d6596cd14614c93b1c15e7426c51 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 11:43:10 +0100 Subject: [PATCH 01/23] ot.lp reorganise to avoid def in __init__ --- CONTRIBUTORS.md | 2 +- RELEASES.md | 2 + ot/lp/__init__.py | 876 +-------------------------------------- ot/lp/barycenter.py | 266 ++++++++++++ ot/lp/network_simplex.py | 612 +++++++++++++++++++++++++++ 5 files changed, 887 insertions(+), 871 deletions(-) create mode 100644 ot/lp/barycenter.py create mode 100644 ot/lp/network_simplex.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 39f0b23d4..6f6a72737 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -48,7 +48,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) -* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) +* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein) * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) diff --git a/RELEASES.md b/RELEASES.md index 0ddac599b..e29be544e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,8 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) +- Implement fixed-point solver for OT barycenters with generic cost functions + (generalizes `ot.lp.free_support_barycenter`). (PR #???) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2b93e84f3..d11a5ee41 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,15 +8,17 @@ # # License: MIT License -import numpy as np -import warnings - from . import cvx from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize +from .network_simplex import emd, emd2 +from .barycenter import ( + free_support_barycenter, + generalized_free_support_barycenter +) # import compiled emd -from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .emd_wrap import emd_1d_sorted from .solver_1d import ( emd_1d, emd2_1d, @@ -26,9 +28,6 @@ semidiscrete_wasserstein2_unif_circle, ) -from ..utils import dist, list_to_array -from ..backend import get_backend - __all__ = [ "emd", "emd2", @@ -46,866 +45,3 @@ "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", ] - - -def check_number_threads(numThreads): - """Checks whether or not the requested number of threads has a valid value. - - Parameters - ---------- - numThreads : int or str - The requested number of threads, should either be a strictly positive integer or "max" or None - - Returns - ------- - numThreads : int - Corrected number of threads - """ - if (numThreads is None) or ( - isinstance(numThreads, str) and numThreads.lower() == "max" - ): - return -1 - if (not isinstance(numThreads, int)) or numThreads < 1: - raise ValueError( - 'numThreads should either be "max" or a strictly positive integer' - ) - return numThreads - - -def center_ot_dual(alpha0, beta0, a=None, b=None): - r"""Center dual OT potentials w.r.t. their weights - - The main idea of this function is to find unique dual potentials - that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having - stability when multiple calling of the OT solver with small changes. - - Basically we add another constraint to the potential that will not - change the objective value but will ensure unicity. The constraint - is the following: - - .. math:: - \alpha^T \mathbf{a} = \beta^T \mathbf{b} - - in addition to the OT problem constraints. - - since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing - a constant from both :math:`\alpha_0` and :math:`\beta_0`. - - .. math:: - c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} - - \alpha &= \alpha_0 + c - - \beta &= \beta_0 + c - - Parameters - ---------- - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - a : (ns,) numpy.ndarray, float64 - Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - - Returns - ------- - alpha : (ns,) numpy.ndarray, float64 - Source centered dual potential - beta : (nt,) numpy.ndarray, float64 - Target centered dual potential - - """ - # if no weights are provided, use uniform - if a is None: - a = np.ones(alpha0.shape[0]) / alpha0.shape[0] - if b is None: - b = np.ones(beta0.shape[0]) / beta0.shape[0] - - # compute constant that balances the weighted sums of the duals - c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) - - # update duals - alpha = alpha0 + c - beta = beta0 - c - - return alpha, beta - - -def estimate_dual_null_weights(alpha0, beta0, a, b, M): - r"""Estimate feasible values for 0-weighted dual potentials - - The feasible values are computed efficiently but rather coarsely. - - .. warning:: - This function is necessary because the C++ solver in `emd_c` - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport - matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. - - First we compute the constraints violations: - - .. math:: - \mathbf{V} = \alpha + \beta^T - \mathbf{M} - - Next we compute the max amount of violation per row (:math:`\alpha`) and - columns (:math:`beta`) - - .. math:: - \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} - - \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} - - Finally we update the dual potential with 0 weights if a - constraint is violated - - .. math:: - \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 - - \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 - - In the end the dual potentials are centered using function - :py:func:`ot.lp.center_ot_dual`. - - Note that all those updates do not change the objective value of the - solution but provide dual potentials that do not violate the constraints. - - Parameters - ---------- - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - alpha0 : (ns,) numpy.ndarray, float64 - Source dual potential - beta0 : (nt,) numpy.ndarray, float64 - Target dual potential - a : (ns,) numpy.ndarray, float64 - Source distribution (uniform weights if empty list) - b : (nt,) numpy.ndarray, float64 - Target distribution (uniform weights if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - - Returns - ------- - alpha : (ns,) numpy.ndarray, float64 - Source corrected dual potential - beta : (nt,) numpy.ndarray, float64 - Target corrected dual potential - - """ - - # binary indexing of non-zeros weights - asel = a != 0 - bsel = b != 0 - - # compute dual constraints violation - constraint_violation = alpha0[:, None] + beta0[None, :] - M - - # Compute largest violation per line and columns - aviol = np.max(constraint_violation, 1) - bviol = np.max(constraint_violation, 0) - - # update corrects violation of - alpha_up = -1 * ~asel * np.maximum(aviol, 0) - beta_up = -1 * ~bsel * np.maximum(bviol, 0) - - alpha = alpha0 + alpha_up - beta = beta0 + beta_up - - return center_ot_dual(alpha, beta, a, b) - - -def emd( - a, - b, - M, - numItermax=100000, - log=False, - center_dual=True, - numThreads=1, - check_marginals=True, -): - r"""Solves the Earth Movers distance problem and returns the OT matrix - - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order - numpy.array in float64 format. It will be converted if not in this - format - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - .. note:: This function will cast the computed transport plan to the data type - of the provided input with the following priority: :math:`\mathbf{a}`, - then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. - Casting to an integer tensor might result in a loss of precision. - If this behaviour is unwanted, please make sure to provide a - floating point input. - - .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. - - Uses the algorithm proposed in :ref:`[1] `. - - Parameters - ---------- - a : (ns,) array-like, float - Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. - center_dual: boolean, optional (default=True) - If True, centers the dual potential using function - :py:func:`ot.lp.center_ot_dual`. - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - check_marginals: bool, optional (default=True) - If True, checks that the marginals mass are equal. If False, skips the - check. - - - Returns - ------- - gamma: array-like, shape (ns, nt) - Optimal transportation matrix for the given - parameters - log: dict, optional - If input log is true, a dictionary containing the - cost and dual variables and exit status - - - Examples - -------- - - Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a, b, M) - array([[0.5, 0. ], - [0. , 0.5]]) - - - .. _references-emd: - References - ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, - December). Displacement interpolation using Lagrangian mass transport. - In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. - - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) - - if len(a) != 0: - type_as = a - elif len(b) != 0: - type_as = b - else: - type_as = M - - # if empty array given then use uniform distributions - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) - - # ensure float64 - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - - assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] - ), "Dimension mismatch, check dimensions of M with a and b" - - # ensure that same mass - if check_marginals: - np.testing.assert_almost_equal( - a.sum(0), - b.sum(0), - err_msg="a and b vector must have the same sum", - decimal=6, - ) - b = b * a.sum() / b.sum() - - asel = a != 0 - bsel = b != 0 - - numThreads = check_number_threads(numThreads) - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - result_code_string = check_result(result_code) - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - if log: - log = {} - log["cost"] = cost - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - return nx.from_numpy(G, type_as=type_as), log - return nx.from_numpy(G, type_as=type_as) - - -def emd2( - a, - b, - M, - processes=1, - numItermax=100000, - log=False, - return_matrix=False, - center_dual=True, - numThreads=1, - check_marginals=True, -): - r"""Solves the Earth Movers distance problem and returns the loss - - .. math:: - \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - .. note:: This function will cast the computed transport plan and - transportation loss to the data type of the provided input with the - following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, - then :math:`\mathbf{M}` if marginals are not provided. - Casting to an integer tensor might result in a loss of precision. - If this behaviour is unwanted, please make sure to provide a - floating point input. - - .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. - - Uses the algorithm proposed in :ref:`[1] `. - - Parameters - ---------- - a : (ns,) array-like, float64 - Source histogram (uniform weight if empty list) - b : (nt,) array-like, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float64 - Loss matrix (for numpy c-order array with type float64) - processes : int, optional (default=1) - Nb of processes used for multiple emd computation (deprecated) - numItermax : int, optional (default=100000) - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: boolean, optional (default=False) - If True, returns a dictionary containing dual - variables. Otherwise returns only the optimal transportation cost. - return_matrix: boolean, optional (default=False) - If True, returns the optimal transportation matrix in the log. - center_dual: boolean, optional (default=True) - If True, centers the dual potential using function - :py:func:`ot.lp.center_ot_dual`. - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - check_marginals: bool, optional (default=True) - If True, checks that the marginals mass are equal. If False, skips the - check. - - - Returns - ------- - W: float, array-like - Optimal transportation loss for the given parameters - log: dict - If input log is true, a dictionary containing dual - variables and exit status - - - Examples - -------- - - Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd2(a,b,M) - 0.0 - - - .. _references-emd2: - References - ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. - - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) - - if len(a) != 0: - type_as = a - elif len(b) != 0: - type_as = b - else: - type_as = M - - # if empty array given then use uniform distributions - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - - # store original tensors - a0, b0, M0 = a, b, M - - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] - ), "Dimension mismatch, check dimensions of M with a and b" - - # ensure that same mass - if check_marginals: - np.testing.assert_almost_equal( - a.sum(0), - b.sum(0, keepdims=True), - err_msg="a and b vector must have the same sum", - decimal=6, - ) - b = b * a.sum(0) / b.sum(0, keepdims=True) - - asel = a != 0 - - numThreads = check_number_threads(numThreads) - - if log or return_matrix: - - def f(b): - bsel = b != 0 - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - result_code_string = check_result(result_code) - log = {} - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - if return_matrix: - log["G"] = G - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), - ) - return [cost, log] - else: - - def f(b): - bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - G, - ), - ) - - check_result(result_code) - return cost - - if len(b.shape) == 1: - return f(b) - nb = b.shape[1] - - if processes > 1: - warnings.warn( - "The 'processes' parameter has been deprecated. " - "Multiprocessing should be done outside of POT." - ) - res = list(map(f, [b[:, i].copy() for i in range(nb)])) - - return res - - -def free_support_barycenter( - measures_locations, - measures_weights, - X_init, - b=None, - weights=None, - numItermax=100, - stopThr=1e-7, - verbose=False, - log=None, - numThreads=1, -): - r""" - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: - - .. math:: - \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) - - where : - - - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) - - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - - This problem is considered in :ref:`[20] ` (Algorithm 2). - There are two differences with the following codes: - - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[20] ` (Algorithm 2). This can be seen as a discrete - implementation of the fixed-point algorithm of - :ref:`[43] ` proposed in the continuous setting. - - Parameters - ---------- - measures_locations : list of N (k_i,d) array-like - The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space - (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) array-like - Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one - representing the weights of each discrete input measure - - X_init : (k,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - b : (k,) array-like - Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) array-like - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - - - Returns - ------- - X : (k,d) array-like - Support locations (on k atoms) of the barycenter - - - .. _references-free-support-barycenter: - - References - ---------- - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. - - """ - - nx = get_backend(*measures_locations, *measures_weights, X_init) - - iter_count = 0 - - N = len(measures_locations) - k = X_init.shape[0] - d = X_init.shape[1] - if b is None: - b = nx.ones((k,), type_as=X_init) / k - if weights is None: - weights = nx.ones((N,), type_as=X_init) / N - - X = X_init - - log_dict = {} - displacement_square_norms = [] - - displacement_square_norm = stopThr + 1.0 - - while displacement_square_norm > stopThr and iter_count < numItermax: - T_sum = nx.zeros((k, d), type_as=X_init) - - for measure_locations_i, measure_weights_i, weight_i in zip( - measures_locations, measures_weights, weights - ): - M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( - T_i, measure_locations_i - ) - - displacement_square_norm = nx.sum((T_sum - X) ** 2) - if log: - displacement_square_norms.append(displacement_square_norm) - - X = T_sum - - if verbose: - print( - "iteration %d, displacement_square_norm=%f\n", - iter_count, - displacement_square_norm, - ) - - iter_count += 1 - - if log: - log_dict["displacement_square_norms"] = displacement_square_norms - return X, log_dict - else: - return X - - -def generalized_free_support_barycenter( - X_list, - a_list, - P_list, - n_samples_bary, - Y_init=None, - b=None, - weights=None, - numItermax=100, - stopThr=1e-7, - verbose=False, - log=None, - numThreads=1, - eps=0, -): - r""" - Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with - a fixed amount of points of uniform weights) whose respective projections fit the input measures. - More formally: - - .. math:: - \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) - - where : - - - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` - - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter - - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` - - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) - - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations - - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) - - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` - - As show by :ref:`[42] `, - this problem can be re-written as a Wasserstein Barycenter problem, - which we solve using the free support method :ref:`[20] ` - (Algorithm 2). - - Parameters - ---------- - X_list : list of p (k_i,d_i) array-like - Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space - (:math:`k_i` can be different for each element of the list) - a_list : list of p (k_i,) array-like - Measure weights: each element is a vector (k_i) on the simplex - P_list : list of p (d_i,d) array-like - Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` - n_samples_bary : int - Number of barycenter points - Y_init : (n_samples_bary,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - b : (n_samples_bary,) array-like - Initialization of the weights of the barycenter measure (on the simplex) - weights : (p,) array-like - Initialization of the coefficients of the barycenter (on the simplex) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) - If compiled with OpenMP, chooses the number of threads to parallelize. - "max" selects the highest number possible. - eps: Stability coefficient for the change of variable matrix inversion - If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix - inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) - - - Returns - ------- - Y : (n_samples_bary,d) array-like - Support locations (on n_samples_bary atoms) of the barycenter - - - .. _references-generalized-free-support-barycenter: - References - ---------- - .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. - - """ - nx = get_backend(*X_list, *a_list, *P_list) - d = P_list[0].shape[1] - p = len(X_list) - - if weights is None: - weights = nx.ones(p, type_as=X_list[0]) / p - - # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) - A = eps * nx.eye( - d, type_as=X_list[0] - ) # if eps nonzero: will force the invertibility of A - for P_i, lambda_i in zip(P_list, weights): - A = A + lambda_i * P_i.T @ P_i - B = nx.inv(nx.sqrtm(A)) - - Z_list = [ - x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) - ] # change of variables -> (WB) problem on Z - - if Y_init is None: - Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) - - if b is None: - b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized - - out = free_support_barycenter( - Z_list, - a_list, - Y_init, - b, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, - numThreads=numThreads, - ) - - if log: # unpack - Y, log_dict = out - else: - Y = out - log_dict = None - Y = Y @ B.T # return to the Generalized WB formulation - - if log: - return Y, log_dict - else: - return Y diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter.py new file mode 100644 index 000000000..5468fb4eb --- /dev/null +++ b/ot/lp/barycenter.py @@ -0,0 +1,266 @@ + +def free_support_barycenter( + measures_locations, + measures_weights, + X_init, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, +): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: + + .. math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] ` (Algorithm 2). + There are two differences with the following codes: + + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[20] ` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] ` proposed in the continuous setting. + + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (k,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (N,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + + .. _references-free-support-barycenter: + + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1.0 + + while displacement_square_norm > stopThr and iter_count < numItermax: + T_sum = nx.zeros((k, d), type_as=X_init) + + for measure_locations_i, measure_weights_i, weight_i in zip( + measures_locations, measures_weights, weights + ): + M_i = dist(X, measure_locations_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) + T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( + T_i, measure_locations_i + ) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print( + "iteration %d, displacement_square_norm=%f\n", + iter_count, + displacement_square_norm, + ) + + iter_count += 1 + + if log: + log_dict["displacement_square_norms"] = displacement_square_norms + return X, log_dict + else: + return X + + +def generalized_free_support_barycenter( + X_list, + a_list, + P_list, + n_samples_bary, + Y_init=None, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, + eps=0, +): + r""" + Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: + + .. math:: + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) + + where : + + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations + - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + + As show by :ref:`[42] `, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] ` + (Algorithm 2). + + Parameters + ---------- + X_list : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a_list : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P_list : list of p (d_i,d) array-like + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + n_samples_bary : int + Number of barycenter points + Y_init : (n_samples_bary,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (on the simplex) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) + + + Returns + ------- + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter + + + .. _references-generalized-free-support-barycenter: + References + ---------- + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. + + """ + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye( + d, type_as=X_list[0] + ) # if eps nonzero: will force the invertibility of A + for P_i, lambda_i in zip(P_list, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z_list = [ + x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) + ] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) + + if b is None: + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized + + out = free_support_barycenter( + Z_list, + a_list, + Y_init, + b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + numThreads=numThreads, + ) + + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalized WB formulation + + if log: + return Y, log_dict + else: + return Y diff --git a/ot/lp/network_simplex.py b/ot/lp/network_simplex.py new file mode 100644 index 000000000..0e820fec6 --- /dev/null +++ b/ot/lp/network_simplex.py @@ -0,0 +1,612 @@ +# -*- coding: utf-8 -*- +""" +Solvers for the original linear program OT problem. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +import warnings + +from ..utils import list_to_array +from ..backend import get_backend +from .emd_wrap import emd_c, check_result + + +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. + + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or ( + isinstance(numThreads, str) and numThreads.lower() == "max" + ): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError( + 'numThreads should either be "max" or a strictly positive integer' + ) + return numThreads + + +def center_ot_dual(alpha0, beta0, a=None, b=None): + r"""Center dual OT potentials w.r.t. their weights + + The main idea of this function is to find unique dual potentials + that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having + stability when multiple calling of the OT solver with small changes. + + Basically we add another constraint to the potential that will not + change the objective value but will ensure unicity. The constraint + is the following: + + .. math:: + \alpha^T \mathbf{a} = \beta^T \mathbf{b} + + in addition to the OT problem constraints. + + since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing + a constant from both :math:`\alpha_0` and :math:`\beta_0`. + + .. math:: + c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} + + \alpha &= \alpha_0 + c + + \beta &= \beta_0 + c + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if empty list) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source centered dual potential + beta : (nt,) numpy.ndarray, float64 + Target centered dual potential + + """ + # if no weights are provided, use uniform + if a is None: + a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + if b is None: + b = np.ones(beta0.shape[0]) / beta0.shape[0] + + # compute constant that balances the weighted sums of the duals + c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + + # update duals + alpha = alpha0 + c + beta = beta0 - c + + return alpha, beta + + +def estimate_dual_null_weights(alpha0, beta0, a, b, M): + r"""Estimate feasible values for 0-weighted dual potentials + + The feasible values are computed efficiently but rather coarsely. + + .. warning:: + This function is necessary because the C++ solver in `emd_c` + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + + First we compute the constraints violations: + + .. math:: + \mathbf{V} = \alpha + \beta^T - \mathbf{M} + + Next we compute the max amount of violation per row (:math:`\alpha`) and + columns (:math:`beta`) + + .. math:: + \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} + + \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} + + Finally we update the dual potential with 0 weights if a + constraint is violated + + .. math:: + \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 + + \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 + + In the end the dual potentials are centered using function + :py:func:`ot.lp.center_ot_dual`. + + Note that all those updates do not change the objective value of the + solution but provide dual potentials that do not violate the constraints. + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source distribution (uniform weights if empty list) + b : (nt,) numpy.ndarray, float64 + Target distribution (uniform weights if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source corrected dual potential + beta : (nt,) numpy.ndarray, float64 + Target corrected dual potential + + """ + + # binary indexing of non-zeros weights + asel = a != 0 + bsel = b != 0 + + # compute dual constraints violation + constraint_violation = alpha0[:, None] + beta0[None, :] - M + + # Compute largest violation per line and columns + aviol = np.max(constraint_violation, 1) + bviol = np.max(constraint_violation, 0) + + # update corrects violation of + alpha_up = -1 * ~asel * np.maximum(aviol, 0) + beta_up = -1 * ~bsel * np.maximum(bviol, 0) + + alpha = alpha0 + alpha_up + beta = beta0 + beta_up + + return center_ot_dual(alpha, beta, a, b) + + +def emd( + a, + b, + M, + numItermax=100000, + log=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem and returns the OT matrix + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan to the data type + of the provided input with the following priority: :math:`\mathbf{a}`, + then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + + Uses the algorithm proposed in :ref:`[1] `. + + Parameters + ---------- + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :py:func:`ot.lp.center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + check_marginals: bool, optional (default=True) + If True, checks that the marginals mass are equal. If False, skips the + check. + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd(a, b, M) + array([[0.5, 0. ], + [0. , 0.5]]) + + + .. _references-emd: + References + ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) + + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b + else: + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # convert to numpy + M, a, b = nx.to_numpy(M, a, b) + + # ensure float64 + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order="C") + + # if empty array given then use uniform distributions + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b = b * a.sum() / b.sum() + + asel = a != 0 + bsel = b != 0 + + numThreads = check_number_threads(numThreads) + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + if log: + log = {} + log["cost"] = cost + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + return nx.from_numpy(G, type_as=type_as), log + return nx.from_numpy(G, type_as=type_as) + + +def emd2( + a, + b, + M, + processes=1, + numItermax=100000, + log=False, + return_matrix=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem and returns the loss + + .. math:: + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan and + transportation loss to the data type of the provided input with the + following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + + Uses the algorithm proposed in :ref:`[1] `. + + Parameters + ---------- + a : (ns,) array-like, float64 + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float64 + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) + processes : int, optional (default=1) + Nb of processes used for multiple emd computation (deprecated) + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :py:func:`ot.lp.center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + check_marginals: bool, optional (default=True) + If True, checks that the marginals mass are equal. If False, skips the + check. + + + Returns + ------- + W: float, array-like + Optimal transportation loss for the given parameters + log: dict + If input log is true, a dictionary containing dual + variables and exit status + + + Examples + -------- + + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays + + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.emd2(a,b,M) + 0.0 + + + .. _references-emd2: + References + ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) + + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b + else: + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M + + # convert to numpy + M, a, b = nx.to_numpy(M, a, b) + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order="C") + + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0, keepdims=True), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b = b * a.sum(0) / b.sum(0, keepdims=True) + + asel = a != 0 + + numThreads = check_number_threads(numThreads) + + if log or return_matrix: + + def f(b): + bsel = b != 0 + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + log = {} + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + if return_matrix: + log["G"] = G + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + ) + return [cost, log] + else: + + def f(b): + bsel = b != 0 + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + G, + ), + ) + + check_result(result_code) + return cost + + if len(b.shape) == 1: + return f(b) + nb = b.shape[1] + + if processes > 1: + warnings.warn( + "The 'processes' parameter has been deprecated. " + "Multiprocessing should be done outside of POT." + ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + + return res From 109edb7534653c767d490703cfd631aad55a6592 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 11:53:38 +0100 Subject: [PATCH 02/23] pr number + enabled pre-commit --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index e29be544e..2eae33215 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,7 +7,7 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Implement fixed-point solver for OT barycenters with generic cost functions - (generalizes `ot.lp.free_support_barycenter`). (PR #???) + (generalizes `ot.lp.free_support_barycenter`). (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 0957904c9d4fb2bdba58a357899077192c1ee52d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 11:57:45 +0100 Subject: [PATCH 03/23] added barycenter.py imports --- ot/lp/barycenter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter.py index 5468fb4eb..b1411abe1 100644 --- a/ot/lp/barycenter.py +++ b/ot/lp/barycenter.py @@ -1,3 +1,7 @@ +from ..backend import get_backend +from ..utils import dist +from .network_simplex import emd + def free_support_barycenter( measures_locations, From 818b3e7a278af75ad5a95c50f3a599775193a768 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 12:10:21 +0100 Subject: [PATCH 04/23] fixed wrong import in ot.gmm --- ot/gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/gmm.py b/ot/gmm.py index cde2f8bbd..5c7a4c287 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -12,7 +12,7 @@ from .backend import get_backend from .lp import emd2, emd import numpy as np -from .lp import dist +from .utils import dist from .gaussian import bures_wasserstein_mapping From 08c2285cafe4a1ee6517e799a043af3251031a6e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 12:24:20 +0100 Subject: [PATCH 05/23] ruff fix attempt --- README.md | 7 ++++++- ot/gromov/_partial.py | 6 +++--- ot/gromov/_quantized.py | 6 +++--- ot/lp/__init__.py | 6 +++--- ot/lp/{barycenter.py => barycenter_solvers.py} | 0 ot/partial.py | 14 +++++++------- ot/utils.py | 4 ++-- 7 files changed, 24 insertions(+), 19 deletions(-) rename ot/lp/{barycenter.py => barycenter_solvers.py} (100%) diff --git a/README.md b/README.md index 7bbae9e8a..dd9622d9d 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,11 @@ POT provides the following generic OT solvers (links to examples): * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. * [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. -* Gaussian Mixture Model OT [69] +* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/others/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) [69]. * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. +* OT Barycenters for generic transport costs []. POT provides the following Machine Learning related solvers: @@ -391,3 +392,7 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing +Barycentres of Measures for Generic Transport +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) \ No newline at end of file diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index c6837f1d3..6672240d0 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -185,7 +185,7 @@ def partial_gromov_wasserstein( if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(np.sum(p), np.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -654,7 +654,7 @@ def partial_fused_gromov_wasserstein( if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(np.sum(p), np.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -1213,7 +1213,7 @@ def entropic_partial_gromov_wasserstein( if m is None: m = min(nx.sum(p), nx.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > min(nx.sum(p), nx.sum(q)): raise ValueError( "Problem infeasible. Parameter m should lower or" diff --git a/ot/gromov/_quantized.py b/ot/gromov/_quantized.py index ac2db5d2d..f4a8fafa7 100644 --- a/ot/gromov/_quantized.py +++ b/ot/gromov/_quantized.py @@ -375,7 +375,7 @@ def get_graph_partition( raise ValueError( f""" Unknown `part_method='{part_method}'`. Use one of: - {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}. + {"random", "louvain", "fluid", "spectral", "GW", "FGW"}. """ ) return nx.from_numpy(part, type_as=C0) @@ -447,7 +447,7 @@ def get_graph_representants(C, part, rep_method="pagerank", random_state=0, nx=N raise ValueError( f""" Unknown `rep_method='{rep_method}'`. Use one of: - {'random', 'pagerank'}. + {"random", "pagerank"}. """ ) @@ -953,7 +953,7 @@ def get_partition_and_representants_samples( else: raise ValueError( f""" - Unknown `method='{method}'`. Use one of: {'random', 'kmeans'} + Unknown `method='{method}'`. Use one of: {"random", "kmeans"} """ ) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d11a5ee41..b29029243 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,9 +12,9 @@ from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize from .network_simplex import emd, emd2 -from .barycenter import ( - free_support_barycenter, - generalized_free_support_barycenter +from .barycenter_solvers import ( + free_support_barycenter, + generalized_free_support_barycenter, ) # import compiled emd diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter_solvers.py similarity index 100% rename from ot/lp/barycenter.py rename to ot/lp/barycenter_solvers.py diff --git a/ot/partial.py b/ot/partial.py index c11ab228a..6b2304e08 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -126,7 +126,7 @@ def partial_wasserstein_lagrange( nx = get_backend(a, b, M) if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors - raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") + raise ValueError("Problem infeasible. Check that a and b are in the simplex") if reg_m is None: reg_m = float(nx.max(M)) + 1 @@ -171,7 +171,7 @@ def partial_wasserstein_lagrange( if log_emd["warning"] is not None: raise ValueError( - "Error in the EMD resolution: try to increase the" " number of dummy points" + "Error in the EMD resolution: try to increase the number of dummy points" ) log_emd["cost"] = nx.sum(gamma * M0) log_emd["u"] = nx.from_numpy(log_emd["u"], type_as=a0) @@ -287,7 +287,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -315,7 +315,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if log_emd["warning"] is not None: raise ValueError( - "Error in the EMD resolution: try to increase the" " number of dummy points" + "Error in the EMD resolution: try to increase the number of dummy points" ) log_emd["partial_w_dist"] = nx.sum(M * gamma) log_emd["u"] = log_emd["u"][: len(a)] @@ -522,7 +522,7 @@ def entropic_partial_wasserstein( if m is None: m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -780,7 +780,7 @@ def partial_gromov_wasserstein( if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > np.min((np.sum(p), np.sum(q))): raise ValueError( "Problem infeasible. Parameter m should lower or" @@ -1132,7 +1132,7 @@ def entropic_partial_gromov_wasserstein( if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") elif m > np.min((np.sum(p), np.sum(q))): raise ValueError( "Problem infeasible. Parameter m should lower or" diff --git a/ot/utils.py b/ot/utils.py index a2d328484..42673ecd6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -517,7 +517,7 @@ def check_random_state(seed): if isinstance(seed, np.random.RandomState): return seed raise ValueError( - "{} cannot be used to seed a numpy.random.RandomState" " instance".format(seed) + "{} cannot be used to seed a numpy.random.RandomState instance".format(seed) ) @@ -787,7 +787,7 @@ def _update_doc(self, olddoc): def _is_deprecated(func): r"""Helper to check if func is wrapped by our deprecated decorator""" if sys.version_info < (3, 5): - raise NotImplementedError("This is only available for python3.5 " "or above") + raise NotImplementedError("This is only available for python3.5 or above") closures = getattr(func, "__closure__", []) if closures is None: closures = [] From f26851586a7c03d4707a8ed710b8047f9acfc78c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 13:33:23 +0100 Subject: [PATCH 06/23] removed ot bar contribs -> only o.lp reorganisation in this PR --- README.md | 5 ----- RELEASES.md | 3 +-- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/README.md b/README.md index dd9622d9d..f64db8f56 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,6 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. -* OT Barycenters for generic transport costs []. POT provides the following Machine Learning related solvers: @@ -392,7 +391,3 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. - -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing -Barycentres of Measures for Generic Transport -Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 2eae33215..1550b479f 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,8 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) -- Implement fixed-point solver for OT barycenters with generic cost functions - (generalizes `ot.lp.free_support_barycenter`). (PR #714) +- Moved functions from `ot/lp/__init__.py` to separate modules. (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 8f24cb95f28e8c1e3f80cb6e72e768f1b45cc2dc Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 13:39:19 +0100 Subject: [PATCH 07/23] add check_number_threads to ot/lp/__init__.py __all__ --- ot/lp/__init__.py | 2 ++ ot/lp/network_simplex.py | 26 +------------------------- ot/utils.py | 24 ++++++++++++++++++++++++ 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b29029243..548200123 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,6 +16,7 @@ free_support_barycenter, generalized_free_support_barycenter, ) +from ..utils import check_number_threads # import compiled emd from .emd_wrap import emd_1d_sorted @@ -44,4 +45,5 @@ "semidiscrete_wasserstein2_unif_circle", "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", + "check_number_threads", ] diff --git a/ot/lp/network_simplex.py b/ot/lp/network_simplex.py index 0e820fec6..492e4c7ac 100644 --- a/ot/lp/network_simplex.py +++ b/ot/lp/network_simplex.py @@ -11,35 +11,11 @@ import numpy as np import warnings -from ..utils import list_to_array +from ..utils import list_to_array, check_number_threads from ..backend import get_backend from .emd_wrap import emd_c, check_result -def check_number_threads(numThreads): - """Checks whether or not the requested number of threads has a valid value. - - Parameters - ---------- - numThreads : int or str - The requested number of threads, should either be a strictly positive integer or "max" or None - - Returns - ------- - numThreads : int - Corrected number of threads - """ - if (numThreads is None) or ( - isinstance(numThreads, str) and numThreads.lower() == "max" - ): - return -1 - if (not isinstance(numThreads, int)) or numThreads < 1: - raise ValueError( - 'numThreads should either be "max" or a strictly positive integer' - ) - return numThreads - - def center_ot_dual(alpha0, beta0, a=None, b=None): r"""Center dual OT potentials w.r.t. their weights diff --git a/ot/utils.py b/ot/utils.py index 42673ecd6..66ff7e354 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1341,3 +1341,27 @@ def proj_SDP(S, nx=None, vmin=0.0): Q = nx.einsum("ijk,ik->ijk", P, w) # Q[i] = P[i] @ diag(w[i]) # R[i] = Q[i] @ P[i].T return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1))) + + +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. + + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or ( + isinstance(numThreads, str) and numThreads.lower() == "max" + ): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError( + 'numThreads should either be "max" or a strictly positive integer' + ) + return numThreads From 3e3b4445f4c1edf588c8d58bb218ccadd5ad0111 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 13:41:29 +0100 Subject: [PATCH 08/23] update releases --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 1550b479f..7d138c9c6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,7 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) -- Moved functions from `ot/lp/__init__.py` to separate modules. (PR #714) +- Reorganize sub-module `ot/lp/__init__.py` into separate files. (PR #714) (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From 566a0fc1e3171cd16cd22b58b926a58cc3c9a2cb Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:07:46 +0100 Subject: [PATCH 09/23] made barycenter_solvers and network_simplex hidden + deprecated ot.lp.cvx --- RELEASES.md | 2 +- ot/lp/__init__.py | 6 +- ...nter_solvers.py => _barycenter_solvers.py} | 156 +++++++++++++++++- ...network_simplex.py => _network_simplex.py} | 0 ot/lp/cvx.py | 148 +---------------- 5 files changed, 163 insertions(+), 149 deletions(-) rename ot/lp/{barycenter_solvers.py => _barycenter_solvers.py} (69%) rename ot/lp/{network_simplex.py => _network_simplex.py} (100%) diff --git a/RELEASES.md b/RELEASES.md index 7d138c9c6..a0474eda0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,7 +6,7 @@ - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) -- Reorganize sub-module `ot/lp/__init__.py` into separate files. (PR #714) (PR #714) +- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 548200123..e3cfce0fd 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -9,10 +9,10 @@ # License: MIT License from . import cvx -from .cvx import barycenter from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize -from .network_simplex import emd, emd2 -from .barycenter_solvers import ( +from ._network_simplex import emd, emd2 +from ._barycenter_solvers import ( + barycenter, free_support_barycenter, generalized_free_support_barycenter, ) diff --git a/ot/lp/barycenter_solvers.py b/ot/lp/_barycenter_solvers.py similarity index 69% rename from ot/lp/barycenter_solvers.py rename to ot/lp/_barycenter_solvers.py index b1411abe1..8b64214d9 100644 --- a/ot/lp/barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -1,6 +1,160 @@ +# -*- coding: utf-8 -*- +""" +OT Barycenter Solvers +""" + +# Author: Remi Flamary +# Eloi Tanguy +# +# License: MIT License + from ..backend import get_backend from ..utils import dist -from .network_simplex import emd +from ._network_simplex import emd + +import numpy as np +import scipy as sp +import scipy.sparse as sps + +try: + import cvxopt # for cvxopt barycenter solver + from cvxopt import solvers, matrix, spmatrix +except ImportError: + cvxopt = False + + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + + +def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"): + r"""Compute the Wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the interior point solver from scipy.optimize. + If cvxopt solver if installed it can use cvxopt + + Note that this problem do not scale well (both in memory and computational time). + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coordinates) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'interior-point' use the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. + + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert len(weights) == A.shape[1] + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) + for i in range(n_distributions): + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] + # row constraints + A_eq1 = sps.hstack( + (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) + ) + + # columns constraints + lst_idiag2 = [] + lst_eye = [] + for i in range(n_distributions): + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + + # full problem + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: + # cvxopt not installed or interior point + + if solver is None: + solver = "interior-point" + + options = {"disp": verbose} + sol = sp.optimize.linprog( + c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options + ) + x = sol.x + b = x[-n:] + + else: + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp( + matrix(c), + scipy_sparse_to_spmatrix(G), + matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), + b=matrix(b_eq), + solver=solver, + ) + + x = np.array(sol["x"]) + b = x[-n:].ravel() + + if log: + return b, sol + else: + return b def free_support_barycenter( diff --git a/ot/lp/network_simplex.py b/ot/lp/_network_simplex.py similarity index 100% rename from ot/lp/network_simplex.py rename to ot/lp/_network_simplex.py diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 01f5e5d87..b2269b8b4 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -1,152 +1,12 @@ # -*- coding: utf-8 -*- """ -LP solvers for optimal transport using cvxopt +(DEPRECATED) LP solvers for optimal transport using cvxopt """ # Author: Remi Flamary # # License: MIT License -import numpy as np -import scipy as sp -import scipy.sparse as sps - -try: - import cvxopt - from cvxopt import solvers, matrix, spmatrix -except ImportError: - cvxopt = False - - -def scipy_sparse_to_spmatrix(A): - """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" - coo = A.tocoo() - SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) - return SP - - -def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"): - r"""Compute the Wasserstein barycenter of distributions A - - The function solves the following optimization problem [16]: - - .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - The linear program is solved using the interior point solver from scipy.optimize. - If cvxopt solver if installed it can use cvxopt - - Note that this problem do not scale well (both in memory and computational time). - - Parameters - ---------- - A : np.ndarray (d,n) - n training distributions a_i of size d - M : np.ndarray (d,d) - loss matrix for OT - reg : float - Regularization term >0 - weights : np.ndarray (n,) - Weights of each histogram a_i on the simplex (barycentric coordinates) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - solver : string, optional - the solver used, default 'interior-point' use the lp solver from - scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. - - Returns - ------- - a : (d,) ndarray - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. - - - """ - - if weights is None: - weights = np.ones(A.shape[1]) / A.shape[1] - else: - assert len(weights) == A.shape[1] - - n_distributions = A.shape[1] - n = A.shape[0] - - n2 = n * n - c = np.zeros((0)) - b_eq1 = np.zeros((0)) - for i in range(n_distributions): - c = np.concatenate((c, M.ravel() * weights[i])) - b_eq1 = np.concatenate((b_eq1, A[:, i])) - c = np.concatenate((c, np.zeros(n))) - - lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] - # row constraints - A_eq1 = sps.hstack( - (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) - ) - - # columns constraints - lst_idiag2 = [] - lst_eye = [] - for i in range(n_distributions): - if i == 0: - lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) - lst_eye.append(-sps.eye(n)) - else: - lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) - lst_eye.append(-sps.eye(n - 1, n)) - - A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) - b_eq2 = np.zeros((A_eq2.shape[0])) - - # full problem - A_eq = sps.vstack((A_eq1, A_eq2)) - b_eq = np.concatenate((b_eq1, b_eq2)) - - if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: - # cvxopt not installed or interior point - - if solver is None: - solver = "interior-point" - - options = {"disp": verbose} - sol = sp.optimize.linprog( - c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options - ) - x = sol.x - b = x[-n:] - - else: - h = np.zeros((n_distributions * n2 + n)) - G = -sps.eye(n_distributions * n2 + n) - - sol = solvers.lp( - matrix(c), - scipy_sparse_to_spmatrix(G), - matrix(h), - A=scipy_sparse_to_spmatrix(A_eq), - b=matrix(b_eq), - solver=solver, - ) - - x = np.array(sol["x"]) - b = x[-n:].ravel() - - if log: - return b, sol - else: - return b +print( + "The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp." +) From 5c35d586ef1b6adf3b5b7d77edb8d90a504904bd Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:10:34 +0100 Subject: [PATCH 10/23] fix ref to lp.cvx in test --- test/test_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ot.py b/test/test_ot.py index da0ec746e..f84f8773a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -395,7 +395,7 @@ def test_generalised_free_support_barycenter_backends(nx): np.testing.assert_allclose(Y, nx.to_numpy(Y2)) -@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +@pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None] From 8ffb06190ce085af685676ac3072335ef5364680 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:23:50 +0100 Subject: [PATCH 11/23] lp.cvx now imports barycenter and gives a warnings.warning --- ot/lp/cvx.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index b2269b8b4..4f7846341 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -7,6 +7,11 @@ # # License: MIT License -print( - "The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp." +import warnings + + +warnings.warn( + "The module ot.lp.cvx is deprecated and will be removed in future versions." + "The function `barycenter` was moved to ot.lp._barycenter_solvers and can" + "be importer via ot.lp." ) From 26748eb0602305ed5d115ad1d7a3b43f352ff06c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 15:28:04 +0100 Subject: [PATCH 12/23] cvx import barycenter --- ot/lp/cvx.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 4f7846341..e88d15375 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -8,6 +8,10 @@ # License: MIT License import warnings +from ._barycenter_solvers import barycenter + + +__all__ = ["barycenter"] warnings.warn( From 081e4eb14285a50f23891cb398472d42da70e724 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 16:52:19 +0100 Subject: [PATCH 13/23] added fixed-point barycenter function to ot.lp._barycenter_solvers_ --- CONTRIBUTORS.md | 2 +- README.md | 4 ++ RELEASES.md | 2 + ot/lp/_barycenter_solvers.py | 87 ++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6f6a72737..fc1ecc313 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -44,7 +44,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW, partial FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein - Barycenters, GMMOT) + Barycenters, GMMOT, Barycenters for General Transport Costs) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) diff --git a/README.md b/README.md index f64db8f56..9a8e5b371 100644 --- a/README.md +++ b/README.md @@ -391,3 +391,7 @@ Artificial Intelligence. [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing +Barycentres of Measures for Generic Transport +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) diff --git a/RELEASES.md b/RELEASES.md index a0474eda0..2a6867484 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,8 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) +- Implement fixed-point solver for OT barycenters with generic cost functions + (generalizes `ot.lp.free_support_barycenter`). (PR #715) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 8b64214d9..7e801caa6 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -422,3 +422,90 @@ def generalized_free_support_barycenter( return Y, log_dict else: return Y + + +class StoppingCriterionReached(Exception): + pass + + +def solve_OT_barycenter_fixed_point( + X_init, + Y_list, + b_list, + cost_list, + B, + max_its=5, + stop_threshold=1e-5, + log=False, +): + """ + Solves the OT barycenter problem using the fixed point algorithm, iterating + the function B on plans between the current barycentre and the measures. + + Parameters + ---------- + X_init : array-like + Array of shape (n, d) representing initial barycentre points. + Y_list : list of array-like + List of K arrays of measure positions, each of shape (m_k, d_k). + b_list : list of array-like + List of K arrays of measure weights, each of shape (m_k). + cost_list : list of callable + List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k). + B : callable + Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre. + max_its : int, optional + Maximum number of iterations (default is 5). + stop_threshold : float, optional + If the iterations move less than this, terminate (default is 1e-5). + log : bool, optional + Whether to return the log dictionary (default is False). + + Returns + ------- + X : array-like + Array of shape (n, d) representing barycentre points. + log_dict : list of array-like, optional + log containing the exit status, list of iterations and list of + displacements if log is True. + """ + nx = get_backend(X_init, Y_list[0]) + K = len(Y_list) + n = X_init.shape[0] + a = nx.ones(n) / n + X_list = [X_init] if log else [] # store the iterations + X = X_init + dX_list = [] # store the displacement squared norms + exit_status = "Unknown" + + try: + for _ in range(max_its): + pi_list = [ # compute the pairwise transport plans + emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K) + ] + Y_perm = [] + for k in range(K): # compute barycentric projections + Y_perm.append(n * pi_list[k] @ Y_list[k]) + X_next = B(Y_perm) + + if log: + X_list.append(X_next) + + # stationary criterion: move less than the threshold + dX = nx.sum((X - X_next) ** 2) + X = X_next + + if log: + dX_list.append(dX) + + if dX < stop_threshold: + exit_status = "Stationary Point" + raise StoppingCriterionReached + + exit_status = "Max iterations reached" + raise StoppingCriterionReached + + except StoppingCriterionReached: + if log: + return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} + return X From 59520198b25a6dd3e2c9f8a403e1846bd77e0995 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 20 Jan 2025 18:06:32 +0100 Subject: [PATCH 14/23] ot bar demo --- .../plot_barycenter_generic_cost.py | 167 ++++++++++++++++++ ...lot_generalized_free_support_barycenter.py | 2 +- examples/others/plot_GMMOT_plan.py | 2 +- examples/others/plot_GMM_flow.py | 2 +- examples/others/plot_SSNB.py | 2 +- ot/gmm.py | 4 +- ot/lp/__init__.py | 3 +- ot/lp/_barycenter_solvers.py | 2 +- ot/mapping.py | 2 +- 9 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 examples/barycenters/plot_barycenter_generic_cost.py diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py new file mode 100644 index 000000000..14779fdff --- /dev/null +++ b/examples/barycenters/plot_barycenter_generic_cost.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +""" +===================================== +OT Barycenter with Generic Costs Demo +===================================== + +This example illustrates the computation of an Optimal Transport for a ground +cost that is not a power of a norm. We take the example of ground costs +:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear) +projection onto a circle k. This is an example of the fixed-point barycenter +solver introduced in [74] which generalises [20]. + +The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over +:math:`x` with Pytorch. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing +Barycentres of Measures for Generic Transport +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) + +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein +Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International +Conference in Machine Learning + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% Generate data +import torch +from torch.optim import Adam +from ot.utils import dist +import numpy as np +from ot.lp import free_support_barycenter_generic_costs +import matplotlib.pyplot as plt + + +torch.manual_seed(42) + +n = 100 # number of points of the of the barycentre +d = 2 # dimensions of the original measure +K = 4 # number of measures to barycentre +m = 50 # number of points of the measures +b_list = [torch.ones(m) / m] * K # weights of the 4 measures +weights = torch.ones(K) / K # weights for the barycentre +stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo + + +# map R^2 -> R^2 projection onto circle +def proj_circle(X, origin, radius): + diffs = X - origin[None, :] + norms = torch.norm(diffs, dim=1) + return origin[None, :] + radius * diffs / norms[:, None] + + +# circles on which to project +origin1 = torch.tensor([-1.0, -1.0]) +origin2 = torch.tensor([-1.0, 2.0]) +origin3 = torch.tensor([2.0, 2.0]) +origin4 = torch.tensor([2.0, -1.0]) +r = np.sqrt(2) +P_list = [ + lambda X: proj_circle(X, origin1, r), + lambda X: proj_circle(X, origin2, r), + lambda X: proj_circle(X, origin3, r), + lambda X: proj_circle(X, origin4, r), +] + +# measures to barycentre are projections of different random circles +# onto the K circles +Y_list = [] +for k in range(K): + t = torch.rand(m) * 2 * np.pi + X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1) + X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :] + Y_list.append(P_list[k](X_temp)) + + +# %% Define costs and ground barycenter function +# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a +# (n, n_k) matrix of costs +def c1(x, y): + return dist(P_list[0](x), y) + + +def c2(x, y): + return dist(P_list[1](x), y) + + +def c3(x, y): + return dist(P_list[2](x), y) + + +def c4(x, y): + return dist(P_list[3](x), y) + + +cost_list = [c1, c2, c3, c4] + + +# batched total ground cost function for candidate points x (n, d) +# for computation of the ground barycenter B with gradient descent +def C(x, y): + """ + Computes the barycenter cost for candidate points x (n, d) and + measure supports y: List(n, d_k). + """ + n = x.shape[0] + K = len(y) + out = torch.zeros(n) + for k in range(K): + out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1) + return out + + +# ground barycenter function +def B(y, its=150, lr=1, stop_threshold=stop_threshold): + """ + Computes the ground barycenter for measure supports y: List(n, d_k). + Output: (n, d) array + """ + x = torch.randn(n, d) + x.requires_grad_(True) + opt = Adam([x], lr=lr) + for _ in range(its): + x_prev = x.data.clone() + opt.zero_grad() + loss = torch.sum(C(x, y)) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < stop_threshold: + break + return x + + +# %% Compute the barycenter measure +fixed_point_its = 10 +X_init = torch.rand(n, d) +X_bar = free_support_barycenter_generic_costs( + X_init, + Y_list, + b_list, + cost_list, + B, + max_its=fixed_point_its, + stop_threshold=stop_threshold, +) + +# %% Plot Barycenter (Iteration 10) +alpha = 0.5 +labels = ["circle 1", "circle 2", "circle 3", "circle 4"] +for Y, label in zip(Y_list, labels): + plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label) +plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha) +plt.axis("equal") +plt.xlim(-0.3, 1.3) +plt.ylim(-0.3, 1.3) +plt.axis("off") +plt.legend() +plt.tight_layout() + +# %% diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index 5b3572bd4..b21c66f13 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -14,7 +14,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # # License: MIT License diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 7742d496e..4964ddd66 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -16,7 +16,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index beb675755..dc26ff3ce 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -10,7 +10,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index fbc343a8a..e167b1ee4 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -38,7 +38,7 @@ 2017. """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # License: MIT License # sphinx_gallery_thumbnail_number = 3 diff --git a/ot/gmm.py b/ot/gmm.py index 5c7a4c287..d99d4e5db 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -3,8 +3,8 @@ Optimal transport for Gaussian Mixtures """ -# Author: Eloi Tanguy -# Remi Flamary +# Author: Eloi Tanguy +# Remi Flamary # Julie Delon # # License: MIT License diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index e3cfce0fd..974679440 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,13 +8,13 @@ # # License: MIT License -from . import cvx from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize from ._network_simplex import emd, emd2 from ._barycenter_solvers import ( barycenter, free_support_barycenter, generalized_free_support_barycenter, + free_support_barycenter_generic_costs, ) from ..utils import check_number_threads @@ -46,4 +46,5 @@ "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", "check_number_threads", + "free_support_barycenter_generic_costs", ] diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 7e801caa6..e45092caa 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -428,7 +428,7 @@ class StoppingCriterionReached(Exception): pass -def solve_OT_barycenter_fixed_point( +def free_support_barycenter_generic_costs( X_init, Y_list, b_list, diff --git a/ot/mapping.py b/ot/mapping.py index 1ec55cb95..d2a05809c 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -7,7 +7,7 @@ use it you need to explicitly import :mod:`ot.mapping` """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # # License: MIT License From 3e8421eb6dca94900bbca636a3594ff413cf5925 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 11:35:53 +0100 Subject: [PATCH 15/23] ot bar doc --- .../plot_barycenter_generic_cost.py | 10 +- ot/lp/_barycenter_solvers.py | 100 ++++++++++++++---- 2 files changed, 87 insertions(+), 23 deletions(-) diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py index 14779fdff..3e5ba38fe 100644 --- a/examples/barycenters/plot_barycenter_generic_cost.py +++ b/examples/barycenters/plot_barycenter_generic_cost.py @@ -6,9 +6,9 @@ This example illustrates the computation of an Optimal Transport for a ground cost that is not a power of a norm. We take the example of ground costs -:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear) +:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) projection onto a circle k. This is an example of the fixed-point barycenter -solver introduced in [74] which generalises [20]. +solver introduced in [74] which generalises [20] and [43]. The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over @@ -22,6 +22,8 @@ Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning +[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + """ # Author: Eloi Tanguy @@ -147,8 +149,8 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): b_list, cost_list, B, - max_its=fixed_point_its, - stop_threshold=stop_threshold, + numItermax=fixed_point_its, + stopThr=stop_threshold, ) # %% Plot Barycenter (Iteration 10) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index e45092caa..a04d4de05 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -430,33 +430,78 @@ class StoppingCriterionReached(Exception): def free_support_barycenter_generic_costs( X_init, - Y_list, - b_list, + measure_locations, + measure_weights, cost_list, B, - max_its=5, - stop_threshold=1e-5, + numItermax=5, + stopThr=1e-5, log=False, ): - """ - Solves the OT barycenter problem using the fixed point algorithm, iterating - the function B on plans between the current barycentre and the measures. + r""" + Solves the OT barycenter problem for generic costs using the fixed point + algorithm, iterating the ground barycenter function B on transport plans + between the current barycentre and the measures. + + The problem finds an optimal barycenter support `X` of given size (n, d) + (enforced by the initialisation), minimising a sum of pairwise transport + costs for the costs :math:`c_k`: + + .. math:: + \min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k), + + where: + + - :math:`X` (n, d) is the barycentre support, + - :math:`a` (n) is the (fixed) barycentre weights, + - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`), + - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), + - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix) + - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`: + + .. math:: + \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F + + s.t. \ \pi \mathbf{1} = \mathbf{a} + + \pi^T \mathbf{1} = \mathbf{b_k} + + \pi \geq 0 + + in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k, + c_k(X, Y_k))`. + + The algorithm requires a given ground barycentre function `B` which computes + a solution of the following minimisation problem given :math:`(y_1, \cdots, + y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`: + + .. math:: + B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), + + where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points + :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times + \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to + this function, and for certain costs it can be computed explicitly of + through a numerical solver. + + This function implements [74] Algorithm 2, which generalises [20] and [43] + to general costs and includes convergence guarantees, including for discrete measures. Parameters ---------- X_init : array-like Array of shape (n, d) representing initial barycentre points. - Y_list : list of array-like + measure_locations : list of array-like List of K arrays of measure positions, each of shape (m_k, d_k). - b_list : list of array-like + measure_weights : list of array-like List of K arrays of measure weights, each of shape (m_k). cost_list : list of callable - List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k). + List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`. B : callable - Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre. - max_its : int, optional + Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre. + numItermax : int, optional Maximum number of iterations (default is 5). - stop_threshold : float, optional + stopThr : float, optional If the iterations move less than this, terminate (default is 1e-5). log : bool, optional Whether to return the log dictionary (default is False). @@ -468,9 +513,25 @@ def free_support_barycenter_generic_costs( log_dict : list of array-like, optional log containing the exit status, list of iterations and list of displacements if log is True. + + .. _references-free-support-barycenter-generic-costs: + + References + ---------- + .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) + + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + See Also + -------- + ot.lp.free_support_barycenter : Free support solver for the case where + :math:`c_k(x,y) = \|x-y\|_2^2`. + ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. """ - nx = get_backend(X_init, Y_list[0]) - K = len(Y_list) + nx = get_backend(X_init, measure_locations[0]) + K = len(measure_locations) n = X_init.shape[0] a = nx.ones(n) / n X_list = [X_init] if log else [] # store the iterations @@ -479,13 +540,14 @@ def free_support_barycenter_generic_costs( exit_status = "Unknown" try: - for _ in range(max_its): + for _ in range(numItermax): pi_list = [ # compute the pairwise transport plans - emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K) + emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) + for k in range(K) ] Y_perm = [] for k in range(K): # compute barycentric projections - Y_perm.append(n * pi_list[k] @ Y_list[k]) + Y_perm.append(n * pi_list[k] @ measure_locations[k]) X_next = B(Y_perm) if log: @@ -498,7 +560,7 @@ def free_support_barycenter_generic_costs( if log: dX_list.append(dX) - if dX < stop_threshold: + if dX < stopThr: exit_status = "Stationary Point" raise StoppingCriterionReached From ccf608a19e515b8f3b664792532f6c1b5136ca5f Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 15:08:00 +0100 Subject: [PATCH 16/23] doc fixes + ot bar coverage --- .../plot_barycenter_generic_cost.py | 46 +++++---- ot/lp/_barycenter_solvers.py | 61 +++++++----- test/test_ot.py | 95 +++++++++++++++++++ 3 files changed, 161 insertions(+), 41 deletions(-) diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py index 3e5ba38fe..e5e5af73a 100644 --- a/examples/barycenters/plot_barycenter_generic_cost.py +++ b/examples/barycenters/plot_barycenter_generic_cost.py @@ -10,19 +10,20 @@ projection onto a circle k. This is an example of the fixed-point barycenter solver introduced in [74] which generalises [20] and [43]. -The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in -\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over +The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over :math:`x` with Pytorch. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing -Barycentres of Measures for Generic Transport -Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. +arXiv preprint 2501.04016 (2024) -[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein -Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International -Conference in Machine Learning +[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein +Barycenters. InternationalConference in Machine Learning -[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. +[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in +Wasserstein space. Journal of Mathematical Analysis and Applications 441.2 +(2016): 744-762. """ @@ -32,7 +33,8 @@ # sphinx_gallery_thumbnail_number = 1 -# %% Generate data +# %% +# Generate data import torch from torch.optim import Adam from ot.utils import dist @@ -43,7 +45,7 @@ torch.manual_seed(42) -n = 100 # number of points of the of the barycentre +n = 200 # number of points of the of the barycentre d = 2 # dimensions of the original measure K = 4 # number of measures to barycentre m = 50 # number of points of the measures @@ -82,7 +84,8 @@ def proj_circle(X, origin, radius): Y_list.append(P_list[k](X_temp)) -# %% Define costs and ground barycenter function +# %% +# Define costs and ground barycenter function # cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a # (n, n_k) matrix of costs def c1(x, y): @@ -140,25 +143,30 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): return x -# %% Compute the barycenter measure -fixed_point_its = 10 +# %% +# Compute the barycenter measure +fixed_point_its = 3 X_init = torch.rand(n, d) X_bar = free_support_barycenter_generic_costs( - X_init, Y_list, b_list, + X_init, cost_list, B, numItermax=fixed_point_its, stopThr=stop_threshold, ) -# %% Plot Barycenter (Iteration 10) -alpha = 0.5 +# %% +# Plot Barycenter (Iteration 3) +alpha = 0.4 +s = 80 labels = ["circle 1", "circle 2", "circle 3", "circle 4"] for Y, label in zip(Y_list, labels): - plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label) -plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha) + plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +plt.scatter( + *(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s +) plt.axis("equal") plt.xlim(-0.3, 1.3) plt.ylim(-0.3, 1.3) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index a04d4de05..445a996df 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -429,11 +429,12 @@ class StoppingCriterionReached(Exception): def free_support_barycenter_generic_costs( - X_init, measure_locations, measure_weights, + X_init, cost_list, B, + a=None, numItermax=5, stopThr=1e-5, log=False, @@ -441,7 +442,7 @@ def free_support_barycenter_generic_costs( r""" Solves the OT barycenter problem for generic costs using the fixed point algorithm, iterating the ground barycenter function B on transport plans - between the current barycentre and the measures. + between the current barycenter and the measures. The problem finds an optimal barycenter support `X` of given size (n, d) (enforced by the initialisation), minimising a sum of pairwise transport @@ -452,12 +453,13 @@ def free_support_barycenter_generic_costs( where: - - :math:`X` (n, d) is the barycentre support, - - :math:`a` (n) is the (fixed) barycentre weights, - - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`), + - :math:`X` (n, d) is the barycenter support, + - :math:`a` (n) is the (fixed) barycenter weights, + - :math:`Y_k` (m_k, d_k) is the k-th measure support + (`measure_locations[k]`), - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix) - - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`: + - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`: .. math:: \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F @@ -471,9 +473,10 @@ def free_support_barycenter_generic_costs( in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k, c_k(X, Y_k))`. - The algorithm requires a given ground barycentre function `B` which computes - a solution of the following minimisation problem given :math:`(y_1, \cdots, - y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`: + The algorithm requires a given ground barycenter function `B` which computes + (broadcasted of `n`) solutions of the following minimisation problem given + :math:`(Y_1, \cdots, Y_K) \in + \mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: .. math:: B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), @@ -482,23 +485,32 @@ def free_support_barycenter_generic_costs( :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to this function, and for certain costs it can be computed explicitly of - through a numerical solver. + through a numerical solver. The input function B takes a list of K arrays of + shape (n, d_k) and returns an array of shape (n, d). This function implements [74] Algorithm 2, which generalises [20] and [43] - to general costs and includes convergence guarantees, including for discrete measures. + to general costs and includes convergence guarantees, including for discrete + measures. Parameters ---------- - X_init : array-like - Array of shape (n, d) representing initial barycentre points. measure_locations : list of array-like List of K arrays of measure positions, each of shape (m_k, d_k). measure_weights : list of array-like List of K arrays of measure weights, each of shape (m_k). + X_init : array-like + Array of shape (n, d) representing initial barycenter points. cost_list : list of callable - List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`. + List of K cost functions :math:`c_k: \mathbb{R}^{n\times + d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times + m_k}`. B : callable - Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre. + Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays + of shape (n\times d_K), computing the ground barycenters (broadcasted + over n). + a : array-like, optional + Array of shape (n,) representing weights of the barycenter + measure.Defaults to uniform. numItermax : int, optional Maximum number of iterations (default is 5). stopThr : float, optional @@ -509,7 +521,7 @@ def free_support_barycenter_generic_costs( Returns ------- X : array-like - Array of shape (n, d) representing barycentre points. + Array of shape (n, d) representing barycenter points. log_dict : list of array-like, optional log containing the exit status, list of iterations and list of displacements if log is True. @@ -518,22 +530,27 @@ def free_support_barycenter_generic_costs( References ---------- - .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) + .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing + barycenters of Measures for Generic Transport Costs. arXiv preprint + 2501.04016 (2024) - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein + barycenters." International Conference on Machine Learning. 2014. - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to + barycenters in Wasserstein space." Journal of Mathematical Analysis and + Applications 441.2 (2016): 744-762. See Also -------- - ot.lp.free_support_barycenter : Free support solver for the case where - :math:`c_k(x,y) = \|x-y\|_2^2`. + ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. """ nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) n = X_init.shape[0] - a = nx.ones(n) / n + if a is None: + a = nx.ones(n, type_as=X_init) / n X_list = [X_init] if log else [] # store the iterations X = X_init dX_list = [] # store the displacement squared norms diff --git a/test/test_ot.py b/test/test_ot.py index f84f8773a..4916d71aa 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -13,6 +13,8 @@ from ot.datasets import make_1D_gauss as gauss from ot.backend import torch, tf +# import ot.lp._barycenter_solvers # TODO: remove this import + def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch @@ -395,6 +397,99 @@ def test_generalised_free_support_barycenter_backends(nx): np.testing.assert_allclose(Y, nx.to_numpy(Y2)) +def test_free_support_barycenter_generic_costs(): + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + + X_init = np.array([-12.0]).reshape((1, 1)) + + # obvious barycenter location between two Diracs + bar_locations = np.array([0.0]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def B(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, measures_weights, X_init, cost_list, B + ) + + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + # test with log and specific weights + X2, log = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + B, + a=ot.unif(1), + log=True, + ) + + assert "X_list" in log + assert "exit_status" in log + assert "dX_list" in log + + np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7) + + # test with one iteration for Max Iterations Reached + X3, log2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + B, + numItermax=1, + log=True, + ) + assert log2["exit_status"] == "Max iterations reached" + + +def test_free_support_barycenter_generic_costs_backends(nx): + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + X_init = np.array([-12.0]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def B(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, measures_weights, X_init, cost_list, B + ) + + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) + X_init2 = nx.from_numpy(X_init) + + X2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations2, measures_weights2, X_init2, cost_list, B + ) + + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + @pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] From 37b9c80cad43f3b71768a265a4c57ef57734e06c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 15:46:20 +0100 Subject: [PATCH 17/23] python 3.13 in test workflow + added ggmot barycenter (WIP) --- .github/workflows/build_tests.yml | 2 +- ot/gmm.py | 114 +++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 4356daa2b..52b4e1d99 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -47,7 +47,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"] steps: - uses: actions/checkout@v4 diff --git a/ot/gmm.py b/ot/gmm.py index d99d4e5db..bf4e700d3 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -13,7 +13,7 @@ from .lp import emd2, emd import numpy as np from .utils import dist -from .gaussian import bures_wasserstein_mapping +from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter def gaussian_logpdf(x, m, C): @@ -440,3 +440,115 @@ def Tk0k1(k0, k1): ] ) return nx.sum(mat, axis=(0, 1)) + + +def solve_gmm_barycenter_fixed_point( + means, + covs, + means_list, + covs_list, + b_list, + weights, + max_its=300, + log=False, + barycentric_proj_method="euclidean", +): + r""" + Solves the GMM OT barycenter problem using the fixed point algorithm. + + Parameters + ---------- + means : array-like + Initial (n, d) GMM means. + covs : array-like + Initial (n, d, d) GMM covariances. + means_list : list of array-like + List of K (m_k, d) GMM means. + covs_list : list of array-like + List of K (m_k, d, d) GMM covariances. + b_list : list of array-like + List of K (m_k) arrays of weights. + weights : array-like + Array (K,) of the barycentre coefficients. + max_its : int, optional + Maximum number of iterations (default is 300). + log : bool, optional + Whether to return the list of iterations (default is False). + barycentric_proj_method : str, optional + Method to project the barycentre weights: 'euclidean' (default) or 'bures'. + + Returns + ------- + means : array-like + (n, d) barycentre GMM means. + covs : array-like + (n, d, d) barycentre GMM covariances. + log_dict : dict, optional + Dictionary containing the list of iterations if log is True. + """ + nx = get_backend(means, covs[0], means_list[0], covs_list[0]) + K = len(means_list) + n = means.shape[0] + d = means.shape[1] + means_its = [means.copy()] + covs_its = [covs.copy()] + a = nx.ones(n, type_as=means) / n + + for _ in range(max_its): + pi_list = [ + gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k]) + for k in range(K) + ] + + means_selection, covs_selection = None, None + # in the euclidean case, the selection of Gaussians from each K sources + # comes from a barycentric projection is a convex combination of the + # selected means and covariances, which can be computed without a + # for loop on i + if barycentric_proj_method == "euclidean": + means_selection = nx.zeros((n, K, d), type_as=means) + covs_selection = nx.zeros((n, K, d, d), type_as=means) + + for k in range(K): + means_selection[:, k, :] = n * pi_list[k] @ means_list[k] + covs_selection[:, k, :, :] = ( + nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n + ) + + # each component i of the barycentre will be a Bures barycentre of the + # selected components of the K GMMs. In the 'bures' barycentric + # projection option, the selected components are also Bures barycentres. + for i in range(n): + # means_slice_i (K, d) is the selected means, each comes from a + # Gaussian barycentre along the disintegration of pi_k at i + # covs_slice_i (K, d, d) are the selected covariances + means_selection_i = [] + covs_selection_i = [] + + # use previous computation (convex combination) + if barycentric_proj_method == "euclidean": + means_selection_i = means_selection[i] + covs_selection_i = covs_selection[i] + + # compute Bures barycentre of the selected components + elif barycentric_proj_method == "bures": + w = (1 / a[i]) * pi_list[k][i, :] + for k in range(K): + m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) + means_selection_i.append(m) + covs_selection_i.append(C) + + else: + raise ValueError("Unknown barycentric_proj_method") + + means[i], covs[i] = bures_wasserstein_barycenter( + means_selection_i, covs_selection_i, weights + ) + + if log: + means_its.append(means.copy()) + covs_its.append(covs.copy()) + + if log: + return means, covs, {"means_its": means_its, "covs_its": covs_its} + return means, covs From a20d3f0656e0e64c0dc4b7a74e94cc9a407c9bd9 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 16:06:43 +0100 Subject: [PATCH 18/23] fixed github action file --- .github/workflows/build_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 52b4e1d99..a8e27b323 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -47,7 +47,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 From 0b6217b00188f4f01bc80f5de7ba838e039cb39e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 17:19:56 +0100 Subject: [PATCH 19/23] ot bar doc + test coverage --- .github/workflows/build_tests.yml | 2 +- ot/gmm.py | 103 ++++++++++++++++++++---------- ot/lp/_barycenter_solvers.py | 4 +- test/test_gmm.py | 54 +++++++++++++++- 4 files changed, 124 insertions(+), 39 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index a8e27b323..4356daa2b 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -47,7 +47,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/ot/gmm.py b/ot/gmm.py index bf4e700d3..214720d1e 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -442,36 +442,50 @@ def Tk0k1(k0, k1): return nx.sum(mat, axis=(0, 1)) -def solve_gmm_barycenter_fixed_point( - means, - covs, +def gmm_barycenter_fixed_point( means_list, covs_list, - b_list, + w_list, + means_init, + covs_init, weights, - max_its=300, + w_bar=None, + iterations=100, log=False, barycentric_proj_method="euclidean", ): r""" - Solves the GMM OT barycenter problem using the fixed point algorithm. + Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) + using the fixed point algorithm (proposed in [74]). The + weights of the barycenter are not optimized, and stay the same as the input + `w_list` or are initialized to uniform. + + The algorithm uses barycentric projections of GMM-OT plans, and these can be + computed either through Bures Barycenters (slow but accurate, + barycentric_proj_method='bures') or by convex combination (fast, + barycentric_proj_method='euclidean', default). + + This is a special case of the generic free-support barycenter solver + `ot.lp.free_support_barycenter_generic_costs`. Parameters ---------- - means : array-like - Initial (n, d) GMM means. - covs : array-like - Initial (n, d, d) GMM covariances. means_list : list of array-like List of K (m_k, d) GMM means. covs_list : list of array-like List of K (m_k, d, d) GMM covariances. - b_list : list of array-like + w_list : list of array-like List of K (m_k) arrays of weights. + means_init : array-like + Initial (n, d) GMM means. + covs_init : array-like + Initial (n, d, d) GMM covariances. weights : array-like Array (K,) of the barycentre coefficients. - max_its : int, optional - Maximum number of iterations (default is 300). + w_bar : array-like, optional + Initial weights (n) of the barycentre GMM. If None, initialized to uniform. + iterations : int, optional + Number of iterations (default is 100). log : bool, optional Whether to return the list of iterations (default is False). barycentric_proj_method : str, optional @@ -485,30 +499,46 @@ def solve_gmm_barycenter_fixed_point( (n, d, d) barycentre GMM covariances. log_dict : dict, optional Dictionary containing the list of iterations if log is True. + + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + + .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) + + See Also + -------- + ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs. """ - nx = get_backend(means, covs[0], means_list[0], covs_list[0]) + nx = get_backend( + means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights + ) K = len(means_list) - n = means.shape[0] - d = means.shape[1] - means_its = [means.copy()] - covs_its = [covs.copy()] - a = nx.ones(n, type_as=means) / n + n = means_init.shape[0] + d = means_init.shape[1] + means_its = [nx.copy(means_init)] + covs_its = [nx.copy(covs_init)] + means, covs = means_init, covs_init + + if w_bar is None: + w_bar = nx.ones(n, type_as=means) / n - for _ in range(max_its): + for _ in range(iterations): pi_list = [ - gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k]) + gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k]) for k in range(K) ] + # filled in the euclidean case means_selection, covs_selection = None, None + # in the euclidean case, the selection of Gaussians from each K sources - # comes from a barycentric projection is a convex combination of the - # selected means and covariances, which can be computed without a - # for loop on i + # comes from a barycentric projection: it is a convex combination of the + # selected means and covariances, which can be computed without a + # for loop on i = 0, ..., n -1 if barycentric_proj_method == "euclidean": means_selection = nx.zeros((n, K, d), type_as=means) covs_selection = nx.zeros((n, K, d, d), type_as=means) - for k in range(K): means_selection[:, k, :] = n * pi_list[k] @ means_list[k] covs_selection[:, k, :, :] = ( @@ -519,24 +549,27 @@ def solve_gmm_barycenter_fixed_point( # selected components of the K GMMs. In the 'bures' barycentric # projection option, the selected components are also Bures barycentres. for i in range(n): - # means_slice_i (K, d) is the selected means, each comes from a + # means_selection_i (K, d) is the selected means, each comes from a # Gaussian barycentre along the disintegration of pi_k at i - # covs_slice_i (K, d, d) are the selected covariances - means_selection_i = [] - covs_selection_i = [] + # covs_selection_i (K, d, d) are the selected covariances + means_selection_i = None + covs_selection_i = None # use previous computation (convex combination) if barycentric_proj_method == "euclidean": means_selection_i = means_selection[i] covs_selection_i = covs_selection[i] - # compute Bures barycentre of the selected components + # compute Bures barycentre of certain components to get the + # selection at i elif barycentric_proj_method == "bures": - w = (1 / a[i]) * pi_list[k][i, :] + means_selection_i = nx.zeros((K, d), type_as=means) + covs_selection_i = nx.zeros((K, d, d), type_as=means) for k in range(K): + w = (1 / w_bar[i]) * pi_list[k][i, :] m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) - means_selection_i.append(m) - covs_selection_i.append(C) + means_selection_i[k] = m + covs_selection_i[k] = C else: raise ValueError("Unknown barycentric_proj_method") @@ -546,8 +579,8 @@ def solve_gmm_barycenter_fixed_point( ) if log: - means_its.append(means.copy()) - covs_its.append(covs.copy()) + means_its.append(nx.copy(means)) + covs_its.append(nx.copy(covs)) if log: return means, covs, {"means_its": means_its, "covs_its": covs_its} diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 445a996df..9589121bd 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -435,7 +435,7 @@ def free_support_barycenter_generic_costs( cost_list, B, a=None, - numItermax=5, + numItermax=100, stopThr=1e-5, log=False, ): @@ -512,7 +512,7 @@ def free_support_barycenter_generic_costs( Array of shape (n,) representing weights of the barycenter measure.Defaults to uniform. numItermax : int, optional - Maximum number of iterations (default is 5). + Maximum number of iterations (default is 100). stopThr : float, optional If the iterations move less than this, terminate (default is 1e-5). log : bool, optional diff --git a/test/test_gmm.py b/test/test_gmm.py index 5f1a92965..629a68d57 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -1,6 +1,6 @@ """Tests for module gaussian""" -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # @@ -17,6 +17,7 @@ gmm_ot_plan, gmm_ot_apply_map, gmm_ot_plan_density, + gmm_barycenter_fixed_point, ) try: @@ -193,3 +194,54 @@ def test_gmm_ot_plan_density(nx): with pytest.raises(AssertionError): gmm_ot_plan_density(x[:, 1:], y, m_s, m_t, C_s, C_t, w_s, w_t) + + +@pytest.skip_backend("tf") # skips because of array assignment +@pytest.skip_backend("jax") +def test_gmm_barycenter_fixed_point(nx): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) + means_list = [m_s, m_t] + covs_list = [C_s, C_t] + w_list = [w_s, w_t] + n_iter = 3 + n = m_s.shape[0] # number of components of barycenter + means_init = m_s + covs_init = C_s + weights = nx.ones(2, type_as=m_s) / 2 # barycenter coefficients + + # with euclidean barycentric projections + means, covs = gmm_barycenter_fixed_point( + means_list, covs_list, w_list, means_init, covs_init, weights, iterations=n_iter + ) + + # with bures barycentric projections and assigned weights to uniform + means_bures_proj, covs_bures_proj, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + iterations=n_iter, + w_bar=nx.ones(n, type_as=m_s) / n, + barycentric_proj_method="bures", + log=True, + ) + + assert "means_its" in log + assert "covs_its" in log + + assert np.allclose(means, means_bures_proj, atol=1e-6) + assert np.allclose(covs, covs_bures_proj, atol=1e-6) + + with pytest.raises(ValueError): + gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + iterations=n_iter, + barycentric_proj_method="unknown", + ) From 21bf86b944f2ce6cb71f381718c50095ca485850 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 17:52:11 +0100 Subject: [PATCH 20/23] examples: ot bar with projections onto circles + gmm ot bar --- README.md | 4 +- ...t_free_support_barycenter_generic_cost.py} | 8 +- examples/barycenters/plot_gmm_barycenter.py | 144 ++++++++++++++++++ 3 files changed, 149 insertions(+), 7 deletions(-) rename examples/barycenters/{plot_barycenter_generic_cost.py => plot_free_support_barycenter_generic_cost.py} (96%) create mode 100644 examples/barycenters/plot_gmm_barycenter.py diff --git a/README.md b/README.md index 9a8e5b371..9266c99c6 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,4 @@ Artificial Intelligence. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing -Barycentres of Measures for Generic Transport -Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py similarity index 96% rename from examples/barycenters/plot_barycenter_generic_cost.py rename to examples/barycenters/plot_free_support_barycenter_generic_cost.py index e5e5af73a..55a75b157 100644 --- a/examples/barycenters/plot_barycenter_generic_cost.py +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -4,8 +4,8 @@ OT Barycenter with Generic Costs Demo ===================================== -This example illustrates the computation of an Optimal Transport for a ground -cost that is not a power of a norm. We take the example of ground costs +This example illustrates the computation of an Optimal Transport Barycenter for +a ground cost that is not a power of a norm. We take the example of ground costs :math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) projection onto a circle k. This is an example of the fixed-point barycenter solver introduced in [74] which generalises [20] and [43]. @@ -15,8 +15,8 @@ :math:`x` with Pytorch. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing -Barycentres of Measures for Generic Transport Costs. -arXiv preprint 2501.04016 (2024) +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) [20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein Barycenters. InternationalConference in Machine Learning diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py new file mode 100644 index 000000000..07792c0dd --- /dev/null +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +""" +===================================== +Gaussian Mixture Model OT Barycenters +===================================== + +This example illustrates the computation of a barycenter between Gaussian +Mixtures in the sense of GMM-OT [69]. This computation is done using the +fixed-point method for OT barycenters with generic costs [74], for which POT +provides a general solver, and a specific GMM solver. Note that this is a +'free-support' method, implying that the number of components of the barycenter +GMM and their weights are fixed. + +The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over +the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the +Bures-Wasserstein manifold), and to compute barycenters with respect to the +2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a +gaussian mixture is a finite combination of Diracs on specific gaussians, and +two mixtures are compared with the 2-Wasserstein distance on this space with +ground cost the squared Bures distance between gaussians. + +[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space +of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Generate data +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse +import ot +from ot.gmm import gmm_barycenter_fixed_point + + +K = 3 # number of GMMs +d = 2 # dimension +n = 6 # number of components of the desired barycenter + + +def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2): + rng = np.random.RandomState(seed=seed) + means = rng.randn(K, d) + P = rng.randn(K, d, d) * cov_scale + # C[k] = P[k] @ P[k]^T + min_cov_eig * I + covariances = np.einsum("kab,kcb->kac", P, P) + covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)]) + weights = rng.random(K) + weights /= np.sum(weights) + return means, covariances, weights + + +m_list = [5, 6, 7] # number of components in each GMM +offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])] +means_list = [] # list of means for each GMM +covs_list = [] # list of covariances for each GMM +w_list = [] # list of weights for each GMM + +# generate GMMs +for k in range(K): + means, covs, b = get_random_gmm( + m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5 + ) + means = means / 2 + offsets[k][None, :] + means_list.append(means) + covs_list.append(covs) + w_list.append(b) + +# %% +# Compute the barycenter using the fixed-point method +init_means, init_covs, _ = get_random_gmm(n, d, seed=0) +weights = ot.unif(K) # barycenter coefficients +means_bar, covs_bar, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + init_means, + init_covs, + weights, + iterations=3, + log=True, +) + + +# %% +# Define plotting functions + + +# draw a covariance ellipse +def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None): + def eigsorted(cov): + vals, vecs = np.linalg.eigh(cov) + order = vals.argsort()[::-1].copy() + return vals[order], vecs[:, order] + + vals, vecs = eigsorted(C) + theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) + w, h = 2 * nstd * np.sqrt(vals) + ell = Ellipse( + xy=(mu[0], mu[1]), + width=w, + height=h, + alpha=alpha, + angle=theta, + facecolor=color, + edgecolor=color, + label=label, + fill=True, + ) + if ax is None: + ax = plt.gca() + ax.add_artist(ell) + + +# draw a gmm as a set of ellipses with weights shown in alpha value +def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): + for k in range(ms.shape[0]): + draw_cov( + ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax + ) + + +# %% +# Plot the results +fig, ax = plt.subplots(figsize=(6, 6)) +axis = [-4, 4, -2, 6] +ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) +for k in range(K): + draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax) +draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) +ax.axis(axis) +ax.axis("off") + +# %% From 0820e513e3415a1aa03abb6cd6a9acb27a7096d9 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 21 Jan 2025 18:03:59 +0100 Subject: [PATCH 21/23] releases + readme + docs update --- README.md | 2 ++ RELEASES.md | 3 ++- examples/barycenters/plot_gmm_barycenter.py | 2 +- ot/lp/_barycenter_solvers.py | 27 ++++++++++++--------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 9266c99c6..48a4a87fe 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [74] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 74] POT provides the following Machine Learning related solvers: diff --git a/RELEASES.md b/RELEASES.md index ff8496bef..add09378c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,7 +8,8 @@ - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) - Implement fixed-point solver for OT barycenters with generic cost functions - (generalizes `ot.lp.free_support_barycenter`). (PR #715) + (generalizes `ot.lp.free_support_barycenter`), with example. (PR #715) +- Implement fixed-point solver for barycenters between GMMs (PR #715), with example. #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py index 07792c0dd..84d0ee638 100644 --- a/examples/barycenters/plot_gmm_barycenter.py +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -16,7 +16,7 @@ Bures-Wasserstein manifold), and to compute barycenters with respect to the 2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a gaussian mixture is a finite combination of Diracs on specific gaussians, and -two mixtures are compared with the 2-Wasserstein distance on this space with +two mixtures are compared with the 2-Wasserstein distance on this space, where ground cost the squared Bures distance between gaussians. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 9589121bd..5e53c66d2 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -458,7 +458,9 @@ def free_support_barycenter_generic_costs( - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`), - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), - - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix) + - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} + \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function + (which computes the pairwise cost matrix) - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`: .. math:: @@ -475,18 +477,19 @@ def free_support_barycenter_generic_costs( The algorithm requires a given ground barycenter function `B` which computes (broadcasted of `n`) solutions of the following minimisation problem given - :math:`(Y_1, \cdots, Y_K) \in - \mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: + :math:`(Y_1, \cdots, Y_K) \in \mathbb{R}^{n\times + d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: .. math:: B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points - :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times - \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to - this function, and for certain costs it can be computed explicitly of - through a numerical solver. The input function B takes a list of K arrays of - shape (n, d_k) and returns an array of shape (n, d). + :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{n\times + d_1}\times \cdots\times\mathbb{R}^{n\times d_K} \longrightarrow + \mathbb{R}^{n\times d}` is an input to this function, and for certain costs + it can be computed explicitly of through a numerical solver. The input + function B takes a list of K arrays of shape (n, d_k) and returns an array + of shape (n, d). This function implements [74] Algorithm 2, which generalises [20] and [43] to general costs and includes convergence guarantees, including for discrete @@ -526,8 +529,6 @@ def free_support_barycenter_generic_costs( log containing the exit status, list of iterations and list of displacements if log is True. - .. _references-free-support-barycenter-generic-costs: - References ---------- .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing @@ -543,8 +544,10 @@ def free_support_barycenter_generic_costs( See Also -------- - ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`. - ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. + ot.lp.free_support_barycenter : Free support solver for the case where + :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter : + Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` + with :math:`P_k` linear. """ nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) From 6bd4af8b9c280798c2d5d8b617d611340589fdc7 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 12 Mar 2025 15:15:45 +0100 Subject: [PATCH 22/23] ref fix --- README.md | 4 ++-- .../plot_free_support_barycenter_generic_cost.py | 4 ++-- examples/barycenters/plot_gmm_barycenter.py | 6 ++---- ot/gmm.py | 4 ++-- ot/lp/_barycenter_solvers.py | 4 ++-- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 124c5d809..a7f1ff830 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,8 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. -* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [74] -* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 74] +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [76] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 76] POT provides the following Machine Learning related solvers: diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py index 55a75b157..47e2c9236 100644 --- a/examples/barycenters/plot_free_support_barycenter_generic_cost.py +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -8,13 +8,13 @@ a ground cost that is not a power of a norm. We take the example of ground costs :math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) projection onto a circle k. This is an example of the fixed-point barycenter -solver introduced in [74] which generalises [20] and [43]. +solver introduced in [76] which generalises [20] and [43]. The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over :math:`x` with Pytorch. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py index 84d0ee638..f379a9914 100644 --- a/examples/barycenters/plot_gmm_barycenter.py +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -6,7 +6,7 @@ This example illustrates the computation of a barycenter between Gaussian Mixtures in the sense of GMM-OT [69]. This computation is done using the -fixed-point method for OT barycenters with generic costs [74], for which POT +fixed-point method for OT barycenters with generic costs [76], for which POT provides a general solver, and a specific GMM solver. Note that this is a 'free-support' method, implying that the number of components of the barycenter GMM and their weights are fixed. @@ -22,7 +22,7 @@ [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. -[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) @@ -140,5 +140,3 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) ax.axis(axis) ax.axis("off") - -# %% diff --git a/ot/gmm.py b/ot/gmm.py index 214720d1e..a065c73b0 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -456,7 +456,7 @@ def gmm_barycenter_fixed_point( ): r""" Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) - using the fixed point algorithm (proposed in [74]). The + using the fixed point algorithm (proposed in [76]). The weights of the barycenter are not optimized, and stay the same as the input `w_list` or are initialized to uniform. @@ -504,7 +504,7 @@ def gmm_barycenter_fixed_point( ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. - .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) See Also -------- diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 61b4fce49..f803d23db 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -495,7 +495,7 @@ def free_support_barycenter_generic_costs( function B takes a list of K arrays of shape (n, d_k) and returns an array of shape (n, d). - This function implements [74] Algorithm 2, which generalises [20] and [43] + This function implements [76] Algorithm 2, which generalises [20] and [43] to general costs and includes convergence guarantees, including for discrete measures. @@ -535,7 +535,7 @@ def free_support_barycenter_generic_costs( References ---------- - .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) From 51722bf65f1be26a453f5602f07d1ecb4752c896 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 17 Mar 2025 19:54:14 +0100 Subject: [PATCH 23/23] implementation comments --- ot/lp/_barycenter_solvers.py | 133 +++++++++++++++++++++++------------ test/test_ot.py | 99 +++++++++++++++++++++++--- 2 files changed, 178 insertions(+), 54 deletions(-) diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index f803d23db..725af26c4 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -199,14 +199,12 @@ def free_support_barycenter( measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - numItermax : int, optional Max number of iterations stopThr : float, optional @@ -219,13 +217,11 @@ def free_support_barycenter( If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. - Returns ------- X : (k,d) array-like Support locations (on k atoms) of the barycenter - .. _references-free-support-barycenter: References ---------- @@ -428,20 +424,20 @@ def generalized_free_support_barycenter( return Y -class StoppingCriterionReached(Exception): - pass - - def free_support_barycenter_generic_costs( measure_locations, measure_weights, X_init, cost_list, - B, + ground_bary=None, a=None, numItermax=100, stopThr=1e-5, log=False, + ground_bary_lr=1e-2, + ground_bary_numItermax=100, + ground_bary_stopThr=1e-5, + ground_bary_solver="SGD", ): r""" Solves the OT barycenter problem for generic costs using the fixed point @@ -507,14 +503,15 @@ def free_support_barycenter_generic_costs( List of K arrays of measure weights, each of shape (m_k). X_init : array-like Array of shape (n, d) representing initial barycenter points. - cost_list : list of callable + cost_list : list of callable or callable List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times - m_k}`. - B : callable + m_k}`. If cost_list is a single callable, the same cost is used K times. + ground_bary : callable or None, optional Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays of shape (n\times d_K), computing the ground barycenters (broadcasted - over n). + over n). If not provided, done with Adam on PyTorch (requires PyTorch + backend) a : array-like, optional Array of shape (n,) representing weights of the barycenter measure.Defaults to uniform. @@ -524,6 +521,16 @@ def free_support_barycenter_generic_costs( If the iterations move less than this, terminate (default is 1e-5). log : bool, optional Whether to return the log dictionary (default is False). + ground_bary_lr : float, optional + Learning rate for the ground barycenter solver (if auto is used). + ground_bary_numItermax : int, optional + Maximum number of iterations for the ground barycenter solver (if auto + is used). + ground_bary_stopThr : float, optional + Stop threshold for the ground barycenter solver (if auto is used). + ground_bary_solver : str, optional + Solver for auto ground bary solver (torch SGD or Adam). Default is + "SGD". Returns ------- @@ -549,49 +556,85 @@ def free_support_barycenter_generic_costs( See Also -------- ot.lp.free_support_barycenter : Free support solver for the case where - :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter : - Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` - with :math:`P_k` linear. + :math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`. + ot.lp.generalized_free_support_barycenter : Free support solver for the case + where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. """ nx = get_backend(X_init, measure_locations[0]) K = len(measure_locations) n = X_init.shape[0] if a is None: a = nx.ones(n, type_as=X_init) / n + if callable(cost_list): # use the given cost for all K pairs + cost_list = [cost_list] * K + auto_ground_bary = False + + if ground_bary is None: + auto_ground_bary = True + assert str(nx) == "torch", ( + f"Backend {str(nx)} is not compatible with ground_bary=None, it" + "must be provided if not using PyTorch backend" + ) + try: + import torch + from torch.optim import Adam, SGD + + def ground_bary(y, x_init): + x = x_init.clone().detach().requires_grad_(True) + solver = Adam if ground_bary_solver == "Adam" else SGD + opt = solver([x], lr=ground_bary_lr) + for _ in range(ground_bary_numItermax): + x_prev = x.data.clone() + opt.zero_grad() + # inefficient cost computation but compatible + # with the choice of cost_list[k] giving the cost matrix + loss = torch.sum( + torch.stack( + [torch.diag(cost_list[k](x, y[k])) for k in range(K)] + ) + ) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < ground_bary_stopThr: + break + return x.detach() + + except ImportError: + raise ImportError("PyTorch is required to use ground_bary=None") + X_list = [X_init] if log else [] # store the iterations X = X_init dX_list = [] # store the displacement squared norms - exit_status = "Unknown" - - try: - for _ in range(numItermax): - pi_list = [ # compute the pairwise transport plans - emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) - for k in range(K) - ] - Y_perm = [] - for k in range(K): # compute barycentric projections - Y_perm.append(n * pi_list[k] @ measure_locations[k]) - X_next = B(Y_perm) - - if log: - X_list.append(X_next) + exit_status = "Max iterations reached" + + for _ in range(numItermax): + pi_list = [ # compute the pairwise transport plans + emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) + for k in range(K) + ] + Y_perm = [] + for k in range(K): # compute barycentric projections + Y_perm.append(n * pi_list[k] @ measure_locations[k]) + if auto_ground_bary: # use previous position as initialization + X_next = ground_bary(Y_perm, X) + else: + X_next = ground_bary(Y_perm) - # stationary criterion: move less than the threshold - dX = nx.sum((X - X_next) ** 2) - X = X_next + if log: + X_list.append(X_next) - if log: - dX_list.append(dX) + # stationary criterion: move less than the threshold + dX = nx.sum((X - X_next) ** 2) + X = X_next - if dX < stopThr: - exit_status = "Stationary Point" - raise StoppingCriterionReached + if log: + dX_list.append(dX) - exit_status = "Max iterations reached" - raise StoppingCriterionReached + if dX < stopThr: + exit_status = "Stationary Point" + break - except StoppingCriterionReached: - if log: - return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} - return X + if log: + return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} + return X diff --git a/test/test_ot.py b/test/test_ot.py index 4916d71aa..22612fa4a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -13,8 +13,6 @@ from ot.datasets import make_1D_gauss as gauss from ot.backend import torch, tf -# import ot.lp._barycenter_solvers # TODO: remove this import - def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch @@ -414,14 +412,14 @@ def cost(x, y): cost_list = [cost, cost] - def B(y): + def ground_bary(y): out = 0 for yk in y: out += yk / len(y) return out X = ot.lp.free_support_barycenter_generic_costs( - measures_locations, measures_weights, X_init, cost_list, B + measures_locations, measures_weights, X_init, cost_list, ground_bary ) np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) @@ -432,7 +430,7 @@ def B(y): measures_weights, X_init, cost_list, - B, + ground_bary, a=ot.unif(1), log=True, ) @@ -449,12 +447,95 @@ def B(y): measures_weights, X_init, cost_list, - B, + ground_bary, numItermax=1, log=True, ) assert log2["exit_status"] == "Max iterations reached" + # test with a single callable cost + X3, log3 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost, + ground_bary, + numItermax=1, + log=True, + ) + + # test with no ground_bary but in numpy: requires pytorch backend + with pytest.raises(AssertionError): + ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + numItermax=1, + ) + + +@pytest.mark.skipif(not torch, reason="No torch available") +def test_free_support_barycenter_generic_costs_auto_ground_bary(): + measures_locations = [ + torch.tensor([1.0]).reshape((1, 1)), + torch.tensor([2.0]).reshape((1, 1)), + ] + measures_weights = [torch.tensor([1.0]), torch.tensor([1.0])] + + X_init = torch.tensor([1.2]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def ground_bary(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=1, + ) + + X2, log2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + ground_bary_lr=1e-2, + ground_bary_stopThr=1e-20, + ground_bary_numItermax=50, + numItermax=10, + log=True, + ) + + np.testing.assert_allclose(X2.numpy(), X.numpy(), rtol=1e-4, atol=1e-4) + + X3 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + ground_bary_lr=1e-2, + ground_bary_stopThr=1e-20, + ground_bary_numItermax=50, + numItermax=10, + ground_bary_solver="Adam", + ) + + np.testing.assert_allclose(X2.numpy(), X3.numpy(), rtol=1e-3, atol=1e-3) + def test_free_support_barycenter_generic_costs_backends(nx): measures_locations = [ @@ -469,14 +550,14 @@ def cost(x, y): cost_list = [cost, cost] - def B(y): + def ground_bary(y): out = 0 for yk in y: out += yk / len(y) return out X = ot.lp.free_support_barycenter_generic_costs( - measures_locations, measures_weights, X_init, cost_list, B + measures_locations, measures_weights, X_init, cost_list, ground_bary ) measures_locations2 = nx.from_numpy(*measures_locations) @@ -484,7 +565,7 @@ def B(y): X_init2 = nx.from_numpy(X_init) X2 = ot.lp.free_support_barycenter_generic_costs( - measures_locations2, measures_weights2, X_init2, cost_list, B + measures_locations2, measures_weights2, X_init2, cost_list, ground_bary ) np.testing.assert_allclose(X, nx.to_numpy(X2))