diff --git a/RELEASES.md b/RELEASES.md index a24747fb7..d82b88b73 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ - Backend implementation of `ot.dist` for (PR #701) - Updated documentation Quickstart guide and User guide with new API (PR #726) - Fix jax version for auto-grad (PR #732) +- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index c4de87474..ea273c7db 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -9,12 +9,11 @@ # # License: MIT License -import warnings import numpy as np from scipy.optimize import minimize, Bounds from ..backend import get_backend -from ..utils import list_to_array, get_parameter_pair +from ..utils import list_to_array, get_parameter_pair, fun_to_numpy def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div="kl"): @@ -46,9 +45,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) @@ -206,26 +205,27 @@ def lbfgsb_unbalanced( loss matrix reg: float regularization term >=0 - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. 
- If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term: nonnegative (including 0) but cannot be infinity. If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array. - reg_div: string, optional + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + reg_div: string or pair of callable functions, optional (default = 'kl') Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend - regm_div: string, optional + and not tensors from the backend; otherwise they are wrapped to convert to and from Numpy, + which adds a computational overhead. + regm_div: string, optional (default = 'kl') Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix + G0: array-like (dim_a, dim_b), optional (default = None) + Initialization of the transport matrix. If None, the product of the marginals :math:`\mathbf{a} \mathbf{b}^T` is used. 
numItermax : int, optional Max number of iterations stopThr : float, optional @@ -267,26 +267,14 @@ def lbfgsb_unbalanced( ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - # wrap the callable function to handle numpy arrays - if isinstance(reg_div, tuple): - f0, df0 = reg_div - try: - f0(G0) - df0(G0) - except BaseException: - warnings.warn( - "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead" - ) - - def f(x): - return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) - - def df(x): - return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) - - reg_div = (f, df) + # test settings + regm_div = regm_div.lower() + if regm_div not in ["kl", "l2", "tv"]: + raise ValueError( + "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div) + ) - else: + if isinstance(reg_div, str): reg_div = reg_div.lower() if reg_div not in ["entropy", "kl", "l2"]: raise ValueError( @@ -295,16 +283,11 @@ def df(x): ) ) - regm_div = regm_div.lower() - if regm_div not in ["kl", "l2", "tv"]: - raise ValueError( - "Unknown regm_div = {}. 
Must be either 'kl', 'l2' or 'tv'".format(regm_div) - ) - + # convert all inputs to numpy arrays reg_m1, reg_m2 = get_parameter_pair(reg_m) M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) + nx = get_backend(M, a, b, G0) M0 = M dim_a, dim_b = M.shape @@ -315,10 +298,22 @@ def df(x): b = nx.ones(dim_b, type_as=M) / dim_b # convert to numpy - a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg) + if nx.__name__ == "numpy": # remaining parameters which can be arrays + reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg) + else: + a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg) + G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) + # potentially convert the callable function to handle numpy arrays + if isinstance(reg_div, tuple): + f0, df0 = reg_div + f = fun_to_numpy(f0, G0, nx, warn=True) + df = fun_to_numpy(df0, G0, nx, warn=True) + + reg_div = (f, df) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize( @@ -399,26 +394,27 @@ def lbfgsb_unbalanced2( loss matrix reg: float regularization term >=0 - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. - If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 Marginal relaxation term: nonnegative (including 0) but cannot be infinity. If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array. - reg_div: string, optional + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. 
+ reg_div: string or pair of callable functions, optional (default = 'kl') Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend - regm_div: string, optional + and not tensors from the backend; otherwise they are wrapped to convert to and from Numpy, + which adds a computational overhead. + regm_div: string, optional (default = 'kl') Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix + G0: array-like (dim_a, dim_b), optional (default = None) + Initialization of the transport matrix. None corresponds to uniform product. returnCost: string, optional (default = "linear") If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. If `returnCost` = "total", then return the total unbalanced OT loss. diff --git a/ot/utils.py b/ot/utils.py index 1f24fa33f..551ccf7f4 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1473,3 +1473,43 @@ def check_number_threads(numThreads): 'numThreads should either be "max" or a strictly positive integer' ) return numThreads + + +def fun_to_numpy(fun, arr, nx, warn=True): + """Convert a function to a numpy function. + + Parameters + ---------- + fun : callable + The function to convert. + arr : array-like + The input to test the function. Can be from any backend. + nx : Backend + The backend to use for the conversion. + warn : bool, optional + Whether to emit a warning if the function is not compatible with numpy. + Default is True. + Returns + ------- + fun_numpy : callable + `fun` itself if it accepts numpy arrays, otherwise a wrapper converting inputs and outputs through `nx`. 
+ """ + if arr is None: + raise ValueError("arr should not be None to test fun") + + nx_arr = get_backend(arr) + if nx_arr.__name__ != "numpy": + arr = nx.to_numpy(arr) + try: + fun(arr) + return fun + except BaseException: + if warn: + warnings.warn( + "The callable function should be able to handle numpy arrays, a compatible function is created and comes with overhead" + ) + + def fun_numpy(x): + return nx.to_numpy(fun(nx.from_numpy(x))) + + return fun_numpy diff --git a/test/test_utils.py b/test/test_utils.py index 938fd6058..0b2769109 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -731,3 +731,21 @@ def test_exp_bures(nx): # exp_\Lambda(log_\Lambda(Sigma)) = Sigma Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T)) np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5) + + +def test_fun_to_numpy(nx): + arr = np.arange(5) + arrb = nx.from_numpy(arr) + + def fun(x): # backend function + return nx.sum(x) + + fun_numpy = ot.utils.fun_to_numpy(fun, arrb, nx, warn=True) + + res = nx.to_numpy(fun(arrb)) + res_np = fun_numpy(arr) + + np.testing.assert_allclose(res, res_np) + + with pytest.raises(ValueError): + ot.utils.fun_to_numpy(fun, None, nx, warn=True)