Skip to content

Commit 6510029

Browse files
pnavarodevmotion
andauthored
Add mm_unbalanced function (#22)
* Add mm_unbalanced function * Update api.md * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/lib.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Add doctest in mm_unbalanced function --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent a3c1d24 commit 6510029

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

docs/src/api.md

+1
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual
3535
sinkhorn_unbalanced
3636
sinkhorn_unbalanced2
3737
barycenter_unbalanced
38+
mm_unbalanced
3839
```

src/PythonOT.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ export emd,
1212
barycenter_unbalanced,
1313
sinkhorn_unbalanced,
1414
sinkhorn_unbalanced2,
15-
empirical_sinkhorn_divergence
15+
empirical_sinkhorn_divergence,
16+
mm_unbalanced
1617

1718
const pot = PyCall.PyNULL()
1819

src/lib.jl

+63-13
Original file line numberDiff line numberDiff line change
@@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0];
312312
313313
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
314314
315-
julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
315+
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
316316
3×2 Matrix{Float64}:
317-
0.0 0.499964
318-
0.0 0.200188
319-
0.0 0.29983
317+
0.0 0.5
318+
0.0 0.2002
319+
0.0 0.2998
320320
```
321321
322322
It is possible to provide multiple target marginals as columns of a matrix. In this case the
@@ -325,10 +325,10 @@ optimal transport costs are returned:
325325
```jldoctest sinkhorn_unbalanced
326326
julia> ν = [0.0 0.5; 1.0 0.5];
327327
328-
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6)
328+
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
329329
2-element Vector{Float64}:
330-
0.949709
331-
0.449411
330+
0.9497
331+
0.4494
332332
```
333333
334334
See also: [`sinkhorn_unbalanced2`](@ref)
@@ -371,20 +371,19 @@ julia> ν = [0.0, 1.0];
371371
372372
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
373373
374-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
375-
1-element Vector{Float64}:
376-
0.949709
374+
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
375+
0.9497
377376
```
378377
379378
It is possible to provide multiple target marginals as columns of a matrix:
380379
381380
```jldoctest sinkhorn_unbalanced2
382381
julia> ν = [0.0 0.5; 1.0 0.5];
383382
384-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
383+
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
385384
2-element Vector{Float64}:
386-
0.949709
387-
0.449411
385+
0.9497
386+
0.4494
388387
```
389388
390389
See also: [`sinkhorn_unbalanced`](@ref)
@@ -516,3 +515,54 @@ Python function.
516515
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
517516
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
518517
end
518+
519+
"""
520+
mm_unbalanced(a, b, M, reg_m; reg=0, c=a*b', kwargs...)
521+
522+
Solve the unbalanced optimal transport problem and return the OT plan.
523+
The function solves the following optimization problem:
524+
525+
```math
526+
W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F +
527+
\\mathrm{reg_{m1}} \\cdot \\operatorname{div}(\\gamma \\mathbf{1}, a) +
528+
\\mathrm{reg_{m2}} \\cdot \\operatorname{div}(\\gamma^\\mathsf{T} \\mathbf{1}, b) +
529+
\\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c)
530+
```
531+
532+
where
533+
534+
- `M` is the metric cost matrix,
535+
- `a` and `b` are source and target unbalanced distributions,
536+
- `c` is a reference distribution for the regularization,
537+
- `reg_m` is the marginal relaxation term (if it is a scalar or an indexable object of length 1, then the same term is applied to both marginal relaxations), and
538+
- `reg` is a regularization term.
539+
540+
This function is a wrapper of the function
541+
[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
542+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
543+
Python function.
544+
545+
# Examples
546+
547+
```jldoctest
548+
julia> a=[.5, .5];
549+
550+
julia> b=[.5, .5];
551+
552+
julia> M=[1. 36.; 9. 4.];
553+
554+
julia> round.(mm_unbalanced(a, b, M, 5, div="kl"), digits=2)
555+
2×2 Matrix{Float64}:
556+
0.45 0.0
557+
0.0 0.34
558+
559+
julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2)
560+
2×2 Matrix{Float64}:
561+
0.4 0.0
562+
0.0 0.1
563+
```
564+
565+
"""
566+
function mm_unbalanced(a, b, M, reg_m; kwargs...)
567+
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
568+
end

0 commit comments

Comments
 (0)