Skip to content

Commit 170b862

Browse files
authored
Add smooth_ot_dual (#7)
1 parent bd09ed6 commit 170b862

File tree

4 files changed

+75
-2
lines changed

4 files changed

+75
-2
lines changed

docs/make.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using PythonOT
1010
DocMeta.setdocmeta!(PythonOT, :DocTestSetup, :(using PythonOT); recursive=true)
1111

1212
makedocs(;
13-
modules=[PythonOT],
13+
modules=[PythonOT, PythonOT.Smooth],
1414
authors="David Widmann",
1515
repo="https://github.com/devmotion/PythonOT.jl/blob/{commit}{path}#{line}",
1616
sitename="PythonOT.jl",

docs/src/api.md

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,25 @@ emd
77
emd2
88
```
99

10-
## Entropically regularised optimal transport
10+
## Regularized optimal transport
1111

1212
```@docs
1313
sinkhorn
1414
sinkhorn2
1515
barycenter
1616
```
1717

18+
The submodule `Smooth` contains a function for solving regularized optimal
19+
transport problems with L2- and entropic regularization using the dual
20+
formulation. You can load the submodule with
21+
```julia
22+
using PythonOT.Smooth
23+
```
24+
25+
```@docs
26+
PythonOT.Smooth.smooth_ot_dual
27+
```
28+
1829
## Unbalanced optimal transport
1930

2031
```@docs

src/PythonOT.jl

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn
77
const pot = PyCall.PyNULL()
88

99
include("lib.jl")
10+
include("smooth.jl")
1011

1112
function __init__()
1213
return copy!(pot, PyCall.pyimport_conda("ot", "pot", "conda-forge"))

src/smooth.jl

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
module Smooth
2+
3+
using ..PythonOT: PythonOT
4+
using ..PyCall: PyCall
5+
6+
export smooth_ot_dual
7+
8+
"""
9+
smooth_ot_dual(μ, ν, C, ε; reg_type="l2", kwargs...)
10+
11+
Compute the optimal transport plan for a regularized optimal transport problem
12+
with source and target marginals `μ` and `ν`, cost matrix `C` of size
13+
`(length(μ), length(ν))`, and regularization parameter `ε`.
14+
15+
The optimal transport map `γ` is of the same size as `C` and solves
16+
```math
17+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
18+
+ \\varepsilon \\Omega(\\gamma),
19+
```
20+
where ``\\Omega(\\gamma)`` is the L2-regularization term
21+
``\\Omega(\\gamma) = \\|\\gamma\\|_F^2/2`` if `reg_type="l2"` (the default) or
22+
the entropic regularization term
23+
``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` if `reg_type="kl"`.
24+
25+
The function solves the dual formulation[^BSR2018]
26+
```math
27+
\\max_{\\alpha, \\beta} \\mu^{\\mathsf{T}} \\alpha + \\nu^{\\mathsf{T}} \\beta −
28+
\\sum_{j} \\delta_{\\Omega}(\\alpha + \\beta_j - C_j),
29+
```
30+
where ``C_j`` is the ``j``th column of the cost matrix and ``\\delta_{\\Omega}`` is the
31+
conjugate of the regularization term ``\\Omega``.
32+
33+
This function is a wrapper of the function
34+
[`smooth_ot_dual`](https://pythonot.github.io/gen_modules/ot.smooth.html#ot.smooth.smooth_ot_dual)
35+
in the Python Optimal Transport package. Keyword arguments are listed in the documentation
36+
of the Python function.
37+
38+
# Examples
39+
40+
```jldoctest; setup=:(using PythonOT.Smooth)
41+
julia> μ = [0.5, 0.2, 0.3];
42+
43+
julia> ν = [0.0, 1.0];
44+
45+
julia> C = [0.0 1.0;
46+
2.0 0.0;
47+
0.5 1.5];
48+
49+
julia> smooth_ot_dual(μ, ν, C, 0.01)
50+
3×2 Matrix{Float64}:
51+
0.0 0.5
52+
0.0 0.2
53+
0.0 0.300001
54+
```
55+
56+
[^BSR2018]: Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. In *Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS)*.
57+
"""
58+
function smooth_ot_dual(μ, ν, C, ε; kwargs...)
59+
return PythonOT.pot.smooth.smooth_ot_dual(μ, ν, C, ε; kwargs...)
60+
end
61+
end

0 commit comments

Comments
 (0)