Skip to content

[MRG] Fix None init plan in unbalanced lbfgs solvers #731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
82 changes: 44 additions & 38 deletions ot/unbalanced/_lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,26 +206,26 @@ def lbfgsb_unbalanced(
loss matrix
reg: float
regularization term >=0
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_m: float or indexable object of length 1 or 2
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
reg_div: string, optional
c : array-like (dim_a, dim_b), optional (default = None)
Reference measure for the regularization.
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
reg_div: string or pair of callable functions, optional (default = 'kl')
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
of two calable functions returning the reg term and its derivative.
of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
and not tesors from the backend
regm_div: string, optional
and not tensors from the backend
regm_div: string, optional (default = 'kl')
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
G0: array-like (dim_a, dim_b)
Initialization of the transport matrix
G0: array-like (dim_a, dim_b), optional (default = None)
Initialization of the transport matrix. None corresponds to uniform product.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Expand Down Expand Up @@ -267,26 +267,14 @@ def lbfgsb_unbalanced(
ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
"""

# wrap the callable function to handle numpy arrays
if isinstance(reg_div, tuple):
f0, df0 = reg_div
try:
f0(G0)
df0(G0)
except BaseException:
warnings.warn(
"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
)

def f(x):
return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))

def df(x):
return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))

reg_div = (f, df)
# test settings
regm_div = regm_div.lower()
if regm_div not in ["kl", "l2", "tv"]:
raise ValueError(
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
)

else:
if isinstance(reg_div, str):
reg_div = reg_div.lower()
if reg_div not in ["entropy", "kl", "l2"]:
raise ValueError(
Expand All @@ -295,16 +283,11 @@ def df(x):
)
)

regm_div = regm_div.lower()
if regm_div not in ["kl", "l2", "tv"]:
raise ValueError(
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
)

# convert all inputs to numpy arrays
reg_m1, reg_m2 = get_parameter_pair(reg_m)

M, a, b = list_to_array(M, a, b)
nx = get_backend(M, a, b)
nx = get_backend(M, a, b, G0)
M0 = M

dim_a, dim_b = M.shape
Expand All @@ -315,10 +298,33 @@ def df(x):
b = nx.ones(dim_b, type_as=M) / dim_b

# convert to numpy
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
if nx.__name__ == "numpy": # remaining parameters which can be arrays
reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg)
else:
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)

G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0)
c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c)

# wrap the callable function to handle numpy arrays
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we create a function in utils that provide a check/wrapper around functions for numpy conversion? hat seems like something we might need?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done via ot.utils.fun_to_numpy. Could you please double check the implementation ?

if isinstance(reg_div, tuple):
f0, df0 = reg_div
try:
f0(G0)
df0(G0)
except BaseException:
warnings.warn(
"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
)

def f(x):
return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))

def df(x):
return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))

reg_div = (f, df)

_func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)

res = minimize(
Expand Down Expand Up @@ -411,9 +417,9 @@ def lbfgsb_unbalanced2(
Divergence used for regularization.
Can take three values: 'entropy' (negative entropy), or
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
of two calable functions returning the reg term and its derivative.
of two callable functions returning the reg term and its derivative.
Note that the callable functions should be able to handle Numpy arrays
and not tesors from the backend
and not tensors from the backend
regm_div: string, optional
Divergence to quantify the difference between the marginals.
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
Expand Down
Loading