
[MRG] Fix None init plan in unbalanced lbfgs solvers #731

Open · wants to merge 13 commits into base: master
1 change: 1 addition & 0 deletions RELEASES.md
@@ -17,6 +17,7 @@
- Backend implementation of `ot.dist` for (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
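The changelog entry above refers to the new `ot.utils.fun_to_numpy` helper added in the `ot/utils.py` diff below. A minimal sketch of the mechanism, using an illustrative name `wrap_to_numpy` (the real helper additionally probes the function on a test array and only wraps it when that probe fails):

```python
# Illustrative sketch only: move a numpy input into the backend, call the
# user-supplied function, and convert the result back to numpy.
def wrap_to_numpy(fun, nx):
    def wrapped(x):
        return nx.to_numpy(fun(nx.from_numpy(x)))
    return wrapped
```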
94 changes: 45 additions & 49 deletions ot/unbalanced/_lbfgs.py
@@ -9,12 +9,11 @@
#
# License: MIT License

-import warnings
import numpy as np
from scipy.optimize import minimize, Bounds

from ..backend import get_backend
-from ..utils import list_to_array, get_parameter_pair
+from ..utils import list_to_array, get_parameter_pair, fun_to_numpy


def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div="kl"):
@@ -46,9 +45,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
-of two calable functions returning the reg term and its derivative.
+of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
-and not tesors from the backend
+and not tensors from the backend
regm_div: string, optional
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
@@ -206,26 +205,27 @@ def lbfgsb_unbalanced(
loss matrix
reg: float
regularization term >=0
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_m: float or indexable object of length 1 or 2
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
reg_div: string, optional
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_div: string or pair of callable functions, optional (default = 'kl')
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
-of two calable functions returning the reg term and its derivative.
+of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
-and not tesors from the backend
-regm_div: string, optional
+and not tensors from the backend, otherwise functions will be converted to Numpy
+leading to a computational overhead.
+regm_div: string, optional (default = 'kl')
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
-G0: array-like (dim_a, dim_b)
-Initialization of the transport matrix
+G0: array-like (dim_a, dim_b), optional (default = None)
+Initialization of the transport matrix. None corresponds to uniform product.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -267,26 +267,14 @@
ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
"""

-# wrap the callable function to handle numpy arrays
-if isinstance(reg_div, tuple):
-f0, df0 = reg_div
-try:
-f0(G0)
-df0(G0)
-except BaseException:
-warnings.warn(
-"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
-)
-
-def f(x):
-return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))
-
-def df(x):
-return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))
-
-reg_div = (f, df)
+# test settings
+regm_div = regm_div.lower()
+if regm_div not in ["kl", "l2", "tv"]:
+raise ValueError(
+"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
+)

-else:
+if isinstance(reg_div, str):
reg_div = reg_div.lower()
if reg_div not in ["entropy", "kl", "l2"]:
raise ValueError(
@@ -295,16 +283,11 @@ def df(x):
)
)

-regm_div = regm_div.lower()
-if regm_div not in ["kl", "l2", "tv"]:
-raise ValueError(
-"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
-)

# convert all inputs to numpy arrays
reg_m1, reg_m2 = get_parameter_pair(reg_m)

M, a, b = list_to_array(M, a, b)
-nx = get_backend(M, a, b)
+nx = get_backend(M, a, b, G0)
M0 = M

dim_a, dim_b = M.shape
@@ -315,10 +298,22 @@
b = nx.ones(dim_b, type_as=M) / dim_b

# convert to numpy
-a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
+if nx.__name__ == "numpy": # remaining parameters which can be arrays
+reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg)
+else:
+a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)

G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0)
c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c)

+# potentially convert the callable function to handle numpy arrays
+if isinstance(reg_div, tuple):
+f0, df0 = reg_div
+f = fun_to_numpy(f0, G0, nx, warn=True)
+df = fun_to_numpy(df0, G0, nx, warn=True)
+
+reg_div = (f, df)

_func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)

res = minimize(
@@ -399,26 +394,27 @@ def lbfgsb_unbalanced2(
loss matrix
reg: float
regularization term >=0
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_m: float or indexable object of length 1 or 2
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
reg_div: string, optional
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_div: string or pair of callable functions, optional (default = 'kl')
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
-of two calable functions returning the reg term and its derivative.
+of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
-and not tesors from the backend
-regm_div: string, optional
+and not tensors from the backend, otherwise functions will be converted to Numpy
+leading to a computational overhead.
+regm_div: string, optional (default = 'kl')
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
-G0: array-like (dim_a, dim_b)
-Initialization of the transport matrix
+G0: array-like (dim_a, dim_b), optional (default = None)
+Initialization of the transport matrix. None corresponds to uniform product.
returnCost: string, optional (default = "linear")
If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
If `returnCost` = "total", then return the total unbalanced OT loss.
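For the value-returning variant documented just above, a short usage sketch (inputs as in the earlier example; the `returnCost` flag selects which part of the objective is reported):

```python
import numpy as np
import ot

a = np.ones(3) / 3
b = np.ones(4) / 4
M = np.random.rand(3, 4)

# linear part of the unbalanced OT loss vs. the total objective
loss_lin = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0.1, reg_m=5.0, returnCost="linear")
loss_tot = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0.1, reg_m=5.0, returnCost="total")
```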
40 changes: 40 additions & 0 deletions ot/utils.py
@@ -1473,3 +1473,43 @@ def check_number_threads(numThreads):
'numThreads should either be "max" or a strictly positive integer'
)
return numThreads


def fun_to_numpy(fun, arr, nx, warn=True):
"""Convert a function to a numpy function.

Parameters
----------
fun : callable
The function to convert.
arr : array-like
The input to test the function. Can be from any backend.
nx : Backend
The backend to use for the conversion.
warn : bool, optional
Whether to raise a warning if the function is not compatible with numpy.
Default is True.
Returns
-------
fun_numpy : callable
The converted function.
"""
if arr is None:
raise ValueError("arr should not be None to test fun")

nx_arr = get_backend(arr)
if nx_arr.__name__ != "numpy":
arr = nx.to_numpy(arr)
try:
fun(arr)
return fun
except BaseException:
if warn:
warnings.warn(
"The callable function should be able to handle numpy arrays, a compatible function is created and comes with overhead"
)

def fun_numpy(x):
return nx.to_numpy(fun(nx.from_numpy(x)))

return fun_numpy
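A usage sketch for the new helper, assuming a torch installation (`TorchBackend` comes from `ot.backend`; the regularizer is illustrative): the torch-only callable fails when probed with a numpy array, so a wrapped version is returned.

```python
import numpy as np
import torch
import ot
from ot.backend import TorchBackend

nx = TorchBackend()


def f(G):                     # torch-only: raises if G is a numpy array
    return torch.sum(G**2) / 2


G_test = torch.rand(3, 4)     # test input used to probe f

f_np = ot.utils.fun_to_numpy(f, G_test, nx, warn=False)

# the returned callable now accepts numpy input and returns numpy output
val = f_np(np.random.rand(3, 4))
```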
18 changes: 18 additions & 0 deletions test/test_utils.py
@@ -731,3 +731,21 @@ def test_exp_bures(nx):
# exp_\Lambda(log_\Lambda(Sigma)) = Sigma
Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T))
np.testing.assert_allclose(nx.to_numpy(Sigma), nx.to_numpy(Sigma_exp), atol=1e-5)


def test_fun_to_numpy(nx):
arr = np.arange(5)
arrb = nx.from_numpy(arr)

def fun(x): # backend function
return nx.sum(x)

fun_numpy = ot.utils.fun_to_numpy(fun, arrb, nx, warn=True)

res = nx.to_numpy(fun(arrb))
res_np = fun_numpy(arr)

np.testing.assert_allclose(res, res_np)

with pytest.raises(ValueError):
ot.utils.fun_to_numpy(fun, None, nx, warn=True)
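Finally, a hedged end-to-end sketch (again assuming torch) of the configuration this PR targets: a non-numpy backend, no initial plan, and a custom `reg_div` whose callables only handle backend tensors.

```python
import torch
import ot

a = torch.ones(4) / 4
b = torch.ones(5) / 5
M = torch.rand(4, 5)


def f(G):                   # backend-only regularizer (torch in / torch out)
    return torch.sum(G**2) / 2


def df(G):
    return G


# G0 is left at its default (None): the solver builds the uniform product
# a b^T itself and wraps f/df through fun_to_numpy if they cannot handle
# numpy arrays.
G = ot.unbalanced.lbfgsb_unbalanced(
    a, b, M, reg=0.5, reg_m=5.0, reg_div=(f, df), regm_div="kl"
)
print(G.shape)
```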