@@ -30,6 +30,8 @@ julia> emd(μ, ν, C)
30
30
0.0 0.2
31
31
0.0 0.3
32
32
```
33
+
34
+ See also: [`emd2`](@ref)
33
35
"""
34
36
function emd (μ, ν, C; kwargs... )
35
37
return pot. lp. emd (μ, ν, PyCall. PyReverseDims (permutedims (C)); kwargs... )
@@ -64,6 +66,8 @@ julia> C = [0.0 1.0;
64
66
julia> emd2(μ, ν, C)
65
67
0.95
66
68
```
69
+
70
+ See also: [`emd`](@ref)
67
71
"""
68
72
function emd2 (μ, ν, C; kwargs... )
69
73
return pot. lp. emd2 (μ, ν, PyCall. PyReverseDims (permutedims (C)); kwargs... )
@@ -105,6 +109,8 @@ julia> sinkhorn(μ, ν, C, 0.01)
105
109
0.0 0.2
106
110
0.0 0.3
107
111
```
112
+
113
+ See also: [`sinkhorn2`](@ref)
108
114
"""
109
115
function sinkhorn (μ, ν, C, ε; kwargs... )
110
116
return pot. sinkhorn (μ, ν, PyCall. PyReverseDims (permutedims (C)), ε; kwargs... )
@@ -144,6 +150,8 @@ julia> round.(sinkhorn2(μ, ν, C, 0.01); sigdigits=6)
144
150
1-element Vector{Float64}:
145
151
0.95
146
152
```
153
+
154
+ See also: [`sinkhorn`](@ref)
147
155
"""
148
156
function sinkhorn2 (μ, ν, C, ε; kwargs... )
149
157
return pot. sinkhorn2 (μ, ν, PyCall. PyReverseDims (permutedims (C)), ε; kwargs... )
@@ -189,6 +197,8 @@ julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
189
197
0.0 0.200188
190
198
0.0 0.29983
191
199
```
200
+
201
+ See also: [`sinkhorn_unbalanced2`](@ref)
192
202
"""
193
203
function sinkhorn_unbalanced (μ, ν, C, ε, λ; kwargs... )
194
204
return pot. sinkhorn_unbalanced (
@@ -234,9 +244,54 @@ julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
234
244
1-element Vector{Float64}:
235
245
0.949709
236
246
```
247
+
248
+ See also: [`sinkhorn_unbalanced`](@ref)
237
249
"""
238
250
function sinkhorn_unbalanced2 (μ, ν, C, ε, λ; kwargs... )
239
251
return pot. sinkhorn_unbalanced2 (
240
252
μ, ν, PyCall. PyReverseDims (permutedims (C)), ε, λ; kwargs...
241
253
)
242
254
end
255
+
256
+ """
257
+ barycenter(A, C, ε; kwargs...)
258
+
259
+ Compute the entropically regularized Wasserstein barycenter with histograms `A`, cost matrix
260
+ `C`, and entropic regularization parameter `ε`.
261
+
262
+ The Wasserstein barycenter is a histogram and solves
263
+ ```math
264
+ \\ inf_{a} \\ sum_{i} W_{\\ varepsilon,C}(a, a_i),
265
+ ```
266
+ where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\ varepsilon,C}(a, a_i)}``
267
+ is the optimal transport cost for the entropically regularized optimal transport problem
268
+ with marginals ``a`` and ``a_i``, cost matrix ``C``, and entropic regularization parameter
269
+ ``\\ varepsilon``. Optionally, weights of the histograms ``a_i`` can be provided with the
270
+ keyword argument `weights`.
271
+
272
+ This function is a wrapper of the function
273
+ [`barycenter`](https://pythonot.github.io/all.html#ot.barycenter) in the
274
+ Python Optimal Transport package. Keyword arguments are listed in the documentation of the
275
+ Python function.
276
+
277
+ # Examples
278
+
279
+ ```jldoctest
280
+ julia> A = rand(10, 3);
281
+
282
+ julia> A ./= sum(A; dims=1);
283
+
284
+ julia> C = rand(10, 10);
285
+
286
+ julia> isapprox(sum(barycenter(A, C, 0.01; method="sinkhorn_stabilized")), 1; atol=1e-4)
287
+ true
288
+ ```
289
+ """
290
+ function barycenter (A, C, ε; kwargs... )
291
+ return pot. barycenter (
292
+ PyCall. PyReverseDims (permutedims (A)),
293
+ PyCall. PyReverseDims (permutedims (C)),
294
+ ε;
295
+ kwargs... ,
296
+ )
297
+ end
0 commit comments