@@ -8,7 +8,7 @@ def __init__(self, labelmap, level_weights=None, weight=None):
        torch.nn.Module.__init__(self)
        self.labelmap = labelmap
        self.level_weights = [1.0] * len(self.labelmap.levels) if level_weights is None else level_weights
-
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.criterion = []
        if weight is None:
            for level_len in self.labelmap.levels:
@@ -22,7 +22,7 @@ def __init__(self, labelmap, level_weights=None, weight=None):
                else:
                    level_start.append(level_stop[level_id - 1])
                    level_stop.append(level_stop[level_id - 1] + level_len)
-                self.criterion.append(nn.CrossEntropyLoss(weight=weight[level_start[level_id]:level_stop[level_id]],
+                self.criterion.append(nn.CrossEntropyLoss(weight=weight[level_start[level_id]:level_stop[level_id]].to(self.device),
                                                          reduction='none'))

        print('==Using the following weights config for multi level cross entropy loss: {}'.format(self.level_weights))
@@ -52,6 +52,9 @@ def forward(self, outputs, labels, level_labels):
class MultiLabelSMLoss(torch.nn.MultiLabelSoftMarginLoss):
    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
        print(weight)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if weight is not None:
+            weight = weight.to(self.device)
        torch.nn.MultiLabelSoftMarginLoss.__init__(self, weight, size_average, reduce, reduction)

    def forward(self, outputs, labels, level_labels):
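
For reference, a minimal sketch (not taken from this repository) of the pattern the change applies: the per-class weight tensor is moved to the available device before it is handed to the loss constructor, so it sits on the same device as the logits at forward time. The weight values and tensor shapes below are illustrative only.

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical per-class weights; in the diff these come from a slice of `weight`.
    class_weights = torch.tensor([1.0, 2.0, 0.5])

    # Without .to(device), CrossEntropyLoss raises a device-mismatch error on GPU
    # because the logits live on CUDA while the weight tensor stays on the CPU.
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device), reduction='none')

    logits = torch.randn(4, 3, device=device)           # batch of 4 samples, 3 classes
    targets = torch.randint(0, 3, (4,), device=device)
    loss = criterion(logits, targets)                   # per-sample losses, shape (4,)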