@@ -108,8 +108,8 @@ <h1>Source code for ot.lp</h1><div class="highlight"><pre>
108
108
109
109
< span class ="c1 "> # import compiled emd</ span >
110
110
< span class ="kn "> from</ span > < span class ="nn "> .emd_wrap</ span > < span class ="kn "> import</ span > < span class ="n "> emd_c</ span > < span class ="p "> ,</ span > < span class ="n "> check_result</ span > < span class ="p "> ,</ span > < span class ="n "> emd_1d_sorted</ span >
111
- < span class ="kn "> from</ span > < span class ="nn "> .solver_1d</ span > < span class ="kn "> import</ span > < span class ="p "> (</ span > < span class ="n "> emd_1d</ span > < span class ="p "> ,</ span > < span class ="n "> emd2_1d</ span > < span class ="p "> ,</ span > < span class ="n "> wasserstein_1d</ span > < span class ="p "> ,</ span >
112
- < span class ="n "> binary_search_circle</ span > < span class ="p "> ,</ span > < span class ="n "> wasserstein_circle</ span > < span class ="p "> ,</ span >
111
+ < span class ="kn "> from</ span > < span class ="nn "> .solver_1d</ span > < span class ="kn "> import</ span > < span class ="p "> (</ span > < span class ="n "> emd_1d</ span > < span class ="p "> ,</ span > < span class ="n "> emd2_1d</ span > < span class ="p "> ,</ span > < span class ="n "> wasserstein_1d</ span > < span class ="p "> ,</ span >
112
+ < span class ="n "> binary_search_circle</ span > < span class ="p "> ,</ span > < span class ="n "> wasserstein_circle</ span > < span class ="p "> ,</ span >
113
113
< span class ="n "> semidiscrete_wasserstein2_unif_circle</ span > < span class ="p "> )</ span >
114
114
115
115
< span class ="kn "> from</ span > < span class ="nn "> ..utils</ span > < span class ="kn "> import</ span > < span class ="n "> dist</ span > < span class ="p "> ,</ span > < span class ="n "> list_to_array</ span >
@@ -360,7 +360,7 @@ <h1>Source code for ot.lp</h1><div class="highlight"><pre>
360
360
< span class ="sd "> check_marginals: bool, optional (default=True)</ span >
361
361
< span class ="sd "> If True, checks that the marginals mass are equal. If False, skips the</ span >
362
362
< span class ="sd "> check.</ span >
363
- < span class =" sd " > </ span >
363
+
364
364
365
365
< span class ="sd "> Returns</ span >
366
366
< span class ="sd "> -------</ span >
@@ -439,8 +439,8 @@ <h1>Source code for ot.lp</h1><div class="highlight"><pre>
439
439
< span class ="c1 "> # ensure that same mass</ span >
440
440
< span class ="k "> if</ span > < span class ="n "> check_marginals</ span > < span class ="p "> :</ span >
441
441
< span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> testing</ span > < span class ="o "> .</ span > < span class ="n "> assert_almost_equal</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ),</ span >
442
- < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ),</ span > < span class ="n "> err_msg</ span > < span class ="o "> =</ span > < span class ="s1 "> 'a and b vector must have the same sum'</ span > < span class ="p "> ,</ span >
443
- < span class ="n "> decimal</ span > < span class ="o "> =</ span > < span class ="mi "> 6</ span > < span class ="p "> )</ span >
442
+ < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ),</ span > < span class ="n "> err_msg</ span > < span class ="o "> =</ span > < span class ="s1 "> 'a and b vector must have the same sum'</ span > < span class ="p "> ,</ span >
443
+ < span class ="n "> decimal</ span > < span class ="o "> =</ span > < span class ="mi "> 6</ span > < span class ="p "> )</ span >
444
444
< span class ="n "> b</ span > < span class ="o "> =</ span > < span class ="n "> b</ span > < span class ="o "> *</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> ()</ span > < span class ="o "> /</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> ()</ span >
445
445
446
446
< span class ="n "> asel</ span > < span class ="o "> =</ span > < span class ="n "> a</ span > < span class ="o "> !=</ span > < span class ="mi "> 0</ span >
@@ -541,8 +541,8 @@ <h1>Source code for ot.lp</h1><div class="highlight"><pre>
541
541
< span class ="sd "> check_marginals: bool, optional (default=True)</ span >
542
542
< span class ="sd "> If True, checks that the marginals mass are equal. If False, skips the</ span >
543
543
< span class ="sd "> check.</ span >
544
- < span class =" sd " > </ span >
545
- < span class =" sd " > </ span >
544
+
545
+
546
546
< span class ="sd "> Returns</ span >
547
547
< span class ="sd "> -------</ span >
548
548
< span class ="sd "> W: float, array-like</ span >
@@ -607,16 +607,15 @@ <h1>Source code for ot.lp</h1><div class="highlight"><pre>
607
607
< span class ="n "> b</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> float64</ span > < span class ="p "> )</ span >
608
608
< span class ="n "> M</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span > < span class ="n "> M</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> float64</ span > < span class ="p "> ,</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s1 "> 'C'</ span > < span class ="p "> )</ span >
609
609
610
-
611
610
< span class ="k "> assert</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span > < span class ="o "> ==</ span > < span class ="n "> M</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span > < span class ="ow "> and</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span > < span class ="o "> ==</ span > < span class ="n "> M</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ]),</ span > \
612
611
< span class ="s2 "> "Dimension mismatch, check dimensions of M with a and b"</ span >
613
612
614
613
< span class ="c1 "> # ensure that same mass</ span >
615
614
< span class ="k "> if</ span > < span class ="n "> check_marginals</ span > < span class ="p "> :</ span >
616
615
< span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> testing</ span > < span class ="o "> .</ span > < span class ="n "> assert_almost_equal</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ),</ span >
617
- < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ,</ span > < span class ="n "> keepdims</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ),</ span > < span class ="n "> err_msg</ span > < span class ="o "> =</ span > < span class ="s1 "> 'a and b vector must have the same sum'</ span > < span class ="p "> ,</ span >
618
- < span class ="n "> decimal</ span > < span class ="o "> =</ span > < span class ="mi "> 6</ span > < span class ="p "> )</ span >
619
- < span class ="n "> b</ span > < span class ="o "> =</ span > < span class ="n "> b</ span > < span class ="o "> *</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> )</ span > < span class ="o "> /</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ,</ span > < span class ="n "> keepdims</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
616
+ < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ,</ span > < span class ="n "> keepdims</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ),</ span > < span class ="n "> err_msg</ span > < span class ="o "> =</ span > < span class ="s1 "> 'a and b vector must have the same sum'</ span > < span class ="p "> ,</ span >
617
+ < span class ="n "> decimal</ span > < span class ="o "> =</ span > < span class ="mi "> 6</ span > < span class ="p "> )</ span >
618
+ < span class ="n "> b</ span > < span class ="o "> =</ span > < span class ="n "> b</ span > < span class ="o "> *</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> )</ span > < span class ="o "> /</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> (</ span > < span class ="mi "> 0</ span > < span class ="p "> ,</ span > < span class ="n "> keepdims</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
620
619
621
620
< span class ="n "> asel</ span > < span class ="o "> =</ span > < span class ="n "> a</ span > < span class ="o "> !=</ span > < span class ="mi "> 0</ span >
622
621
0 commit comments