Skip to content

Commit f2638c1

Browse files
committed
move weights to gpu/cpu; add stats for num of pred per sample
1 parent 107e9e5 commit f2638c1

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

network/evaluation.py

+6
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ def evaluate(self, predicted_scores, correct_labels, epoch, phase, save_to_tenso
200200

201201
self.predicted_labels = predicted_scores >= np.tile(self.optimal_thresholds, (correct_labels.shape[0], 1))
202202

203+
classes_predicted_per_sample = np.sum(self.predicted_labels, axis=1)
204+
print("Max: {}".format(np.max(classes_predicted_per_sample)))
205+
print("Min: {}".format(np.min(classes_predicted_per_sample)))
206+
print("Mean: {}".format(np.mean(classes_predicted_per_sample)))
207+
print("std: {}".format(np.std(classes_predicted_per_sample)))
208+
203209
level_stop, level_start = [], []
204210
for level_id, level_len in enumerate(self.labelmap.levels):
205211
if level_id == 0:

network/loss.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def __init__(self, labelmap, level_weights=None, weight=None):
88
torch.nn.Module.__init__(self)
99
self.labelmap = labelmap
1010
self.level_weights = [1.0] * len(self.labelmap.levels) if level_weights is None else level_weights
11-
11+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1212
self.criterion = []
1313
if weight is None:
1414
for level_len in self.labelmap.levels:
@@ -22,7 +22,7 @@ def __init__(self, labelmap, level_weights=None, weight=None):
2222
else:
2323
level_start.append(level_stop[level_id - 1])
2424
level_stop.append(level_stop[level_id - 1] + level_len)
25-
self.criterion.append(nn.CrossEntropyLoss(weight=weight[level_start[level_id]:level_stop[level_id]],
25+
self.criterion.append(nn.CrossEntropyLoss(weight=weight[level_start[level_id]:level_stop[level_id]].to(self.device),
2626
reduction='none'))
2727

2828
print('==Using the following weights config for multi level cross entropy loss: {}'.format(self.level_weights))
@@ -52,6 +52,9 @@ def forward(self, outputs, labels, level_labels):
5252
class MultiLabelSMLoss(torch.nn.MultiLabelSoftMarginLoss):
5353
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
5454
print(weight)
55+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56+
if weight is not None:
57+
weight = weight.to(self.device)
5558
torch.nn.MultiLabelSoftMarginLoss.__init__(self, weight, size_average, reduce, reduction)
5659

5760
def forward(self, outputs, labels, level_labels):

0 commit comments

Comments
 (0)