diff --git a/Project.toml b/Project.toml index b31e817..42f0260 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.2" +version = "0.1.3" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/docs/src/api.md b/docs/src/api.md index 157af1d..feaa2d7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -5,6 +5,8 @@ ```@docs emd emd2 +emd_1d +emd2_1d ``` ## Regularized optimal transport diff --git a/src/PythonOT.jl b/src/PythonOT.jl index b1cd684..9b3c534 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -4,6 +4,8 @@ using PyCall: PyCall export emd, emd2, + emd_1d, + emd2_1d, sinkhorn, sinkhorn2, barycenter, diff --git a/src/lib.jl b/src/lib.jl index da10c79..219820f 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -69,6 +69,78 @@ function emd2(μ, ν, C; kwargs...) return pot.lp.emd2(μ, ν, PyCall.PyReverseDims(permutedims(C)); kwargs...) end +""" + emd_1d(xsource, xtarget; kwargs...) + +Compute the optimal transport plan for the Monge-Kantorovich problem with univariate +discrete measures with support `xsource` and `xtarget` as source and target marginals. + +This function is a wrapper of the function +[`emd_1d`](https://pythonot.github.io/all.html#ot.emd_1d) in the Python Optimal Transport +package. Keyword arguments are listed in the documentation of the Python function. + +# Examples + +```jldoctest +julia> xsource = [0.2, 0.5]; + +julia> xtarget = [0.8, 0.3]; + +julia> emd_1d(xsource, xtarget) +2×2 Matrix{Float64}: + 0.0 0.5 + 0.5 0.0 + +julia> histogram_source = [0.8, 0.2]; + +julia> histogram_target = [0.7, 0.3]; + +julia> emd_1d(xsource, xtarget; a=histogram_source, b=histogram_target) +2×2 Matrix{Float64}: + 0.5 0.3 + 0.2 0.0 +``` + +See also: [`emd`](@ref), [`emd2_1d`](@ref) +""" +function emd_1d(xsource, xtarget; kwargs...) + return pot.lp.emd_1d(xsource, xtarget; kwargs...) +end + +""" + emd2_1d(xsource, xtarget; kwargs...) + +Compute the optimal transport cost for the Monge-Kantorovich problem with univariate +discrete measures with support `xsource` and `xtarget` as source and target marginals. + +This function is a wrapper of the function +[`emd2_1d`](https://pythonot.github.io/all.html#ot.emd2_1d) in the Python Optimal Transport +package. Keyword arguments are listed in the documentation of the Python function. + +# Examples + +```jldoctest +julia> xsource = [0.2, 0.5]; + +julia> xtarget = [0.8, 0.3]; + +julia> round(emd2_1d(xsource, xtarget); sigdigits=6) +0.05 + +julia> histogram_source = [0.8, 0.2]; + +julia> histogram_target = [0.7, 0.3]; + +julia> round(emd2_1d(xsource, xtarget; a=histogram_source, b=histogram_target); sigdigits=6) +0.201 +``` + +See also: [`emd2`](@ref), [`emd2_1d`](@ref) +""" +function emd2_1d(xsource, xtarget; kwargs...) + return pot.lp.emd2_1d(xsource, xtarget; kwargs...) +end + """ sinkhorn(μ, ν, C, ε; kwargs...)