Skip to content

Commit c7d0f05

Browse files
Added empircal_sinkhorn_divergence function (#13)
Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent af649a0 commit c7d0f05

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ emd2_1d
1414
```@docs
1515
sinkhorn
1616
sinkhorn2
17+
empirical_sinkhorn_divergence
1718
barycenter
1819
```
1920

src/PythonOT.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ export emd,
1111
barycenter,
1212
barycenter_unbalanced,
1313
sinkhorn_unbalanced,
14-
sinkhorn_unbalanced2
14+
sinkhorn_unbalanced2,
15+
empirical_sinkhorn_divergence
1516

1617
const pot = PyCall.PyNULL()
1718

src/lib.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,43 @@ function sinkhorn2(μ, ν, C, ε; kwargs...)
244244
return pot.sinkhorn2(μ, ν, PyCall.PyReverseDims(permutedims(C)), ε; kwargs...)
245245
end
246246

247+
"""
248+
empirical_sinkhorn_divergence(xsource, xtarget, ε; kwargs...)
249+
250+
Compute the Sinkhorn divergence from empirical data, where `xsource` and `xtarget` are
251+
arrays representing samples in the source domain and target domain, respectively, and `ε`
252+
is the regularization term.
253+
254+
This function is a wrapper of the function
255+
[`ot.bregman.empirical_sinkhorn_divergence`](https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.empirical_sinkhorn_divergence)
256+
in the Python Optimal Transport package. Keyword arguments are listed in the documentation of the Python function.
257+
258+
# Examples
259+
260+
```jldoctest
261+
julia> xsource = [1];
262+
263+
julia> xtarget = [2, 3];
264+
265+
julia> ε = 0.01;
266+
267+
julia> empirical_sinkhorn_divergence(xsource, xtarget, ε) ≈
268+
sinkhorn2([1], [0.5, 0.5], [1 4], ε) -
269+
(
270+
sinkhorn2([1], [1], zeros(1, 1), ε) +
271+
sinkhorn2([0.5, 0.5], [0.5, 0.5], [0 1; 1 0], ε)
272+
) / 2
273+
true
274+
```
275+
276+
See also: [`sinkhorn2`](@ref)
277+
"""
278+
function empirical_sinkhorn_divergence(xsource, xtarget, ε; kwargs...)
279+
return pot.bregman.empirical_sinkhorn_divergence(
280+
reshape(xsource, Val(2)), reshape(xtarget, Val(2)), ε; kwargs...
281+
)
282+
end
283+
247284
"""
248285
sinkhorn_unbalanced(μ, ν, C, ε, λ; kwargs...)
249286

0 commit comments

Comments
 (0)