Skip to content

Commit bd09ed6

Browse files
authored
Add barycenter (#4)
1 parent db36742 commit bd09ed6

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

docs/src/api.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ emd2
1212
```@docs
1313
sinkhorn
1414
sinkhorn2
15+
barycenter
1516
```
1617

1718
## Unbalanced optimal transport

src/PythonOT.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module PythonOT
22

33
using PyCall: PyCall
44

5-
export emd, emd2, sinkhorn, sinkhorn2, sinkhorn_unbalanced, sinkhorn_unbalanced2
5+
export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn_unbalanced2
66

77
const pot = PyCall.PyNULL()
88

src/lib.jl

+55
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ julia> emd(μ, ν, C)
3030
0.0 0.2
3131
0.0 0.3
3232
```
33+
34+
See also: [`emd2`](@ref)
3335
"""
3436
function emd(μ, ν, C; kwargs...)
3537
return pot.lp.emd(μ, ν, PyCall.PyReverseDims(permutedims(C)); kwargs...)
@@ -64,6 +66,8 @@ julia> C = [0.0 1.0;
6466
julia> emd2(μ, ν, C)
6567
0.95
6668
```
69+
70+
See also: [`emd`](@ref)
6771
"""
6872
function emd2(μ, ν, C; kwargs...)
6973
return pot.lp.emd2(μ, ν, PyCall.PyReverseDims(permutedims(C)); kwargs...)
@@ -105,6 +109,8 @@ julia> sinkhorn(μ, ν, C, 0.01)
105109
0.0 0.2
106110
0.0 0.3
107111
```
112+
113+
See also: [`sinkhorn2`](@ref)
108114
"""
109115
function sinkhorn(μ, ν, C, ε; kwargs...)
110116
return pot.sinkhorn(μ, ν, PyCall.PyReverseDims(permutedims(C)), ε; kwargs...)
@@ -144,6 +150,8 @@ julia> round.(sinkhorn2(μ, ν, C, 0.01); sigdigits=6)
144150
1-element Vector{Float64}:
145151
0.95
146152
```
153+
154+
See also: [`sinkhorn`](@ref)
147155
"""
148156
function sinkhorn2(μ, ν, C, ε; kwargs...)
149157
return pot.sinkhorn2(μ, ν, PyCall.PyReverseDims(permutedims(C)), ε; kwargs...)
@@ -189,6 +197,8 @@ julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
189197
0.0 0.200188
190198
0.0 0.29983
191199
```
200+
201+
See also: [`sinkhorn_unbalanced2`](@ref)
192202
"""
193203
function sinkhorn_unbalanced(μ, ν, C, ε, λ; kwargs...)
194204
return pot.sinkhorn_unbalanced(
@@ -234,9 +244,54 @@ julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
234244
1-element Vector{Float64}:
235245
0.949709
236246
```
247+
248+
See also: [`sinkhorn_unbalanced`](@ref)
237249
"""
238250
function sinkhorn_unbalanced2(μ, ν, C, ε, λ; kwargs...)
239251
return pot.sinkhorn_unbalanced2(
240252
μ, ν, PyCall.PyReverseDims(permutedims(C)), ε, λ; kwargs...
241253
)
242254
end
255+
256+
"""
257+
barycenter(A, C, ε; kwargs...)
258+
259+
Compute the entropically regularized Wasserstein barycenter with histograms `A`, cost matrix
260+
`C`, and entropic regularization parameter `ε`.
261+
262+
The Wasserstein barycenter is a histogram and solves
263+
```math
264+
\\inf_{a} \\sum_{i} W_{\\varepsilon,C}(a, a_i),
265+
```
266+
where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C}(a, a_i)}``
267+
is the optimal transport cost for the entropically regularized optimal transport problem
268+
with marginals ``a`` and ``a_i``, cost matrix ``C``, and entropic regularization parameter
269+
``\\varepsilon``. Optionally, weights of the histograms ``a_i`` can be provided with the
270+
keyword argument `weights`.
271+
272+
This function is a wrapper of the function
273+
[`barycenter`](https://pythonot.github.io/all.html#ot.barycenter) in the
274+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
275+
Python function.
276+
277+
# Examples
278+
279+
```jldoctest
280+
julia> A = rand(10, 3);
281+
282+
julia> A ./= sum(A; dims=1);
283+
284+
julia> C = rand(10, 10);
285+
286+
julia> isapprox(sum(barycenter(A, C, 0.01; method="sinkhorn_stabilized")), 1; atol=1e-4)
287+
true
288+
```
289+
"""
290+
function barycenter(A, C, ε; kwargs...)
291+
return pot.barycenter(
292+
PyCall.PyReverseDims(permutedims(A)),
293+
PyCall.PyReverseDims(permutedims(C)),
294+
ε;
295+
kwargs...,
296+
)
297+
end

0 commit comments

Comments
 (0)