@@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
120
120
121
121
nx = get_backend (a , b , M )
122
122
123
- if nx .sum (a ) > 1 or nx .sum (b ) > 1 :
123
+ if nx .sum (a ) > 1 + 1e-15 or nx .sum (b ) > 1 + 1e-15 : # 1e-15 for numerical errors
124
124
raise ValueError ("Problem infeasible. Check that a and b are in the "
125
125
"simplex" )
126
126
@@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
270
270
271
271
nx = get_backend (a , b , M )
272
272
273
+ dim_a , dim_b = M .shape
274
+ if len (a ) == 0 :
275
+ a = nx .ones (dim_a , type_as = a ) / dim_a
276
+ if len (b ) == 0 :
277
+ b = nx .ones (dim_b , type_as = b ) / dim_b
278
+
273
279
if m is None :
274
280
return partial_wasserstein_lagrange (a , b , M , log = log , ** kwargs )
275
281
elif m < 0 :
276
282
raise ValueError ("Problem infeasible. Parameter m should be greater"
277
283
" than 0." )
278
- elif m > nx .min (( nx .sum (a ), nx .sum (b ))):
284
+ elif m > nx .min (nx . stack (( nx .sum (a ), nx .sum (b ) ))):
279
285
raise ValueError ("Problem infeasible. Parameter m should lower or"
280
286
" equal than min(|a|_1, |b|_1)." )
281
287
282
- a0 , b0 , M0 = a , b , M
283
- # convert to humpy
284
- a , b , M = nx .to_numpy (a , b , M )
285
-
286
- b_extended = np .append (b , [(np .sum (a ) - m ) / nb_dummies ] * nb_dummies )
287
- a_extended = np .append (a , [(np .sum (b ) - m ) / nb_dummies ] * nb_dummies )
288
- M_extended = np .zeros ((len (a_extended ), len (b_extended )))
289
- M_extended [- nb_dummies :, - nb_dummies :] = np .max (M ) * 2
290
- M_extended [:len (a ), :len (b )] = M
288
+ b_extension = nx .ones (nb_dummies , type_as = b ) * (nx .sum (a ) - m ) / nb_dummies
289
+ b_extended = nx .concatenate ((b , b_extension ))
290
+ a_extension = nx .ones (nb_dummies , type_as = a ) * (nx .sum (b ) - m ) / nb_dummies
291
+ a_extended = nx .concatenate ((a , a_extension ))
292
+ M_extension = nx .ones ((nb_dummies , nb_dummies ), type_as = M ) * nx .max (M ) * 2
293
+ M_extended = nx .concatenate (
294
+ (nx .concatenate ((M , nx .zeros ((M .shape [0 ], M_extension .shape [1 ]))), axis = 1 ),
295
+ nx .concatenate ((nx .zeros ((M_extension .shape [0 ], M .shape [1 ])), M_extension ), axis = 1 )),
296
+ axis = 0
297
+ )
291
298
292
299
gamma , log_emd = emd (a_extended , b_extended , M_extended , log = True ,
293
300
** kwargs )
294
301
295
- gamma = nx . from_numpy ( gamma [:len (a ), :len (b )], type_as = M )
302
+ gamma = gamma [:len (a ), :len (b )]
296
303
297
304
if log_emd ['warning' ] is not None :
298
305
raise ValueError ("Error in the EMD resolution: try to increase the"
299
306
" number of dummy points" )
300
- log_emd ['partial_w_dist' ] = nx .sum (M0 * gamma )
301
- log_emd ['u' ] = nx . from_numpy ( log_emd ['u' ][:len (a )], type_as = a0 )
302
- log_emd ['v' ] = nx . from_numpy ( log_emd ['v' ][:len (b )], type_as = b0 )
307
+ log_emd ['partial_w_dist' ] = nx .sum (M * gamma )
308
+ log_emd ['u' ] = log_emd ['u' ][:len (a )]
309
+ log_emd ['v' ] = log_emd ['v' ][:len (b )]
303
310
304
311
if log :
305
312
return gamma , log_emd
@@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
389
396
NeurIPS.
390
397
"""
391
398
399
+ a , b , M = list_to_array (a , b , M )
400
+
401
+ nx = get_backend (a , b , M )
402
+
392
403
partial_gw , log_w = partial_wasserstein (a , b , M , m , nb_dummies , log = True ,
393
404
** kwargs )
394
405
log_w ['T' ] = partial_gw
395
406
396
407
if log :
397
- return np .sum (partial_gw * M ), log_w
408
+ return nx .sum (partial_gw * M ), log_w
398
409
else :
399
- return np .sum (partial_gw * M )
410
+ return nx .sum (partial_gw * M )
400
411
401
412
402
413
def gwgrad_partial (C1 , C2 , T ):
@@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
838
849
ot.partial.partial_wasserstein: exact Partial Wasserstein
839
850
"""
840
851
841
- a = np . asarray (a , dtype = np . float64 )
842
- b = np . asarray ( b , dtype = np . float64 )
843
- M = np . asarray ( M , dtype = np . float64 )
852
+ a , b , M = list_to_array (a , b , M )
853
+
854
+ nx = get_backend ( a , b , M )
844
855
845
856
dim_a , dim_b = M .shape
846
- dx = np .ones (dim_a , dtype = np . float64 )
847
- dy = np .ones (dim_b , dtype = np . float64 )
857
+ dx = nx .ones (dim_a , type_as = a )
858
+ dy = nx .ones (dim_b , type_as = b )
848
859
849
860
if len (a ) == 0 :
850
- a = np .ones (dim_a , dtype = np . float64 ) / dim_a
861
+ a = nx .ones (dim_a , type_as = a ) / dim_a
851
862
if len (b ) == 0 :
852
- b = np .ones (dim_b , dtype = np . float64 ) / dim_b
863
+ b = nx .ones (dim_b , type_as = b ) / dim_b
853
864
854
865
if m is None :
855
- m = np .min (( np .sum (a ), np .sum (b ))) * 1.0
866
+ m = nx .min (nx . stack (( nx .sum (a ), nx .sum (b ) ))) * 1.0
856
867
if m < 0 :
857
868
raise ValueError ("Problem infeasible. Parameter m should be greater"
858
869
" than 0." )
859
- if m > np .min (( np .sum (a ), np .sum (b ))):
870
+ if m > nx .min (nx . stack (( nx .sum (a ), nx .sum (b ) ))):
860
871
raise ValueError ("Problem infeasible. Parameter m should lower or"
861
872
" equal than min(|a|_1, |b|_1)." )
862
873
863
874
log_e = {'err' : []}
864
875
865
- # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
866
- K = np .empty (M .shape , dtype = M .dtype )
867
- np .divide (M , - reg , out = K )
868
- np .exp (K , out = K )
869
- np .multiply (K , m / np .sum (K ), out = K )
876
+ if type (a ) == type (b ) == type (M ) == np .ndarray :
877
+ # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
878
+ K = np .empty (M .shape , dtype = M .dtype )
879
+ np .divide (M , - reg , out = K )
880
+ np .exp (K , out = K )
881
+ np .multiply (K , m / np .sum (K ), out = K )
882
+ else :
883
+ K = nx .exp (- M / reg )
884
+ K = K * m / nx .sum (K )
870
885
871
886
err , cpt = 1 , 0
872
- q1 = np .ones (K .shape )
873
- q2 = np .ones (K .shape )
874
- q3 = np .ones (K .shape )
887
+ q1 = nx .ones (K .shape , type_as = K )
888
+ q2 = nx .ones (K .shape , type_as = K )
889
+ q3 = nx .ones (K .shape , type_as = K )
875
890
876
891
while (err > stopThr and cpt < numItermax ):
877
892
Kprev = K
878
893
K = K * q1
879
- K1 = np .dot (np .diag (np .minimum (a / np .sum (K , axis = 1 ), dx )), K )
894
+ K1 = nx .dot (nx .diag (nx .minimum (a / nx .sum (K , axis = 1 ), dx )), K )
880
895
q1 = q1 * Kprev / K1
881
896
K1prev = K1
882
897
K1 = K1 * q2
883
- K2 = np .dot (K1 , np .diag (np .minimum (b / np .sum (K1 , axis = 0 ), dy )))
898
+ K2 = nx .dot (K1 , nx .diag (nx .minimum (b / nx .sum (K1 , axis = 0 ), dy )))
884
899
q2 = q2 * K1prev / K2
885
900
K2prev = K2
886
901
K2 = K2 * q3
887
- K = K2 * (m / np .sum (K2 ))
902
+ K = K2 * (m / nx .sum (K2 ))
888
903
q3 = q3 * K2prev / K
889
904
890
- if np .any (np .isnan (K )) or np .any (np .isinf (K )):
905
+ if nx .any (nx .isnan (K )) or nx .any (nx .isinf (K )):
891
906
print ('Warning: numerical errors at iteration' , cpt )
892
907
break
893
908
if cpt % 10 == 0 :
894
- err = np . linalg .norm (Kprev - K )
909
+ err = nx .norm (Kprev - K )
895
910
if log :
896
911
log_e ['err' ].append (err )
897
912
if verbose :
@@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
901
916
print ('{:5d}|{:8e}|' .format (cpt , err ))
902
917
903
918
cpt = cpt + 1
904
- log_e ['partial_w_dist' ] = np .sum (M * K )
919
+ log_e ['partial_w_dist' ] = nx .sum (M * K )
905
920
if log :
906
921
return K , log_e
907
922
else :