Skip to content

Commit 343890b

Browse files
authored
fix fpgm pruning memory bug runing on Windows test=develop (#285)
1 parent b571201 commit 343890b

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

paddleslim/prune/criterion.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,21 @@ def geometry_median(group, graph):
5757
scores = []
5858
name, value, axis = group[0]
5959
assert (len(value.shape) == 4)
60-
w = value.view()
61-
channel_num = value.shape[0]
62-
w.shape = value.shape[0], np.product(value.shape[1:])
63-
x = w.repeat(channel_num, axis=0)
64-
y = np.zeros_like(x)
65-
for i in range(channel_num):
66-
y[i * channel_num:(i + 1) * channel_num] = np.tile(w[i],
67-
(channel_num, 1))
68-
tmp = np.sqrt(np.sum((x - y)**2, -1))
69-
tmp = tmp.reshape((channel_num, channel_num))
70-
tmp = np.sum(tmp, -1)
60+
61+
def get_distance_sum(value, out_idx):
62+
w = value.view()
63+
w.shape = value.shape[0], np.product(value.shape[1:])
64+
selected_filter = np.tile(w[out_idx], (w.shape[0], 1))
65+
x = w - selected_filter
66+
x = np.sqrt(np.sum(x * x, -1))
67+
return x.sum()
68+
69+
dist_sum_list = []
70+
for out_i in range(value.shape[0]):
71+
dist_sum = get_distance_sum(value, out_i)
72+
dist_sum_list.append(dist_sum)
73+
74+
tmp = np.array(dist_sum_list)
7175

7276
for name, value, axis in group:
7377
scores.append((name, axis, tmp))

0 commit comments

Comments
 (0)