@@ -57,17 +57,21 @@ def geometry_median(group, graph):
57
57
scores = []
58
58
name , value , axis = group [0 ]
59
59
assert (len (value .shape ) == 4 )
60
- w = value .view ()
61
- channel_num = value .shape [0 ]
62
- w .shape = value .shape [0 ], np .product (value .shape [1 :])
63
- x = w .repeat (channel_num , axis = 0 )
64
- y = np .zeros_like (x )
65
- for i in range (channel_num ):
66
- y [i * channel_num :(i + 1 ) * channel_num ] = np .tile (w [i ],
67
- (channel_num , 1 ))
68
- tmp = np .sqrt (np .sum ((x - y )** 2 , - 1 ))
69
- tmp = tmp .reshape ((channel_num , channel_num ))
70
- tmp = np .sum (tmp , - 1 )
60
+
61
+ def get_distance_sum (value , out_idx ):
62
+ w = value .view ()
63
+ w .shape = value .shape [0 ], np .product (value .shape [1 :])
64
+ selected_filter = np .tile (w [out_idx ], (w .shape [0 ], 1 ))
65
+ x = w - selected_filter
66
+ x = np .sqrt (np .sum (x * x , - 1 ))
67
+ return x .sum ()
68
+
69
+ dist_sum_list = []
70
+ for out_i in range (value .shape [0 ]):
71
+ dist_sum = get_distance_sum (value , out_i )
72
+ dist_sum_list .append (dist_sum )
73
+
74
+ tmp = np .array (dist_sum_list )
71
75
72
76
for name , value , axis in group :
73
77
scores .append ((name , axis , tmp ))
0 commit comments