@@ -137,34 +137,34 @@ def prepare_model(self, loading=False):
137
137
# modify last layers based on the model being used
138
138
if not loading :
139
139
if self .model_name in ['alexnet' , 'vgg' ]:
140
- num_features = self .model .classifier [6 ].in_features
140
+ num_features = self .model .module . classifier [6 ].in_features
141
141
if isinstance (self .criterion , LastLevelCELoss ):
142
- self .model .classifier [6 ] = nn .Linear (num_features , self .levels [- 1 ])
142
+ self .model .module . classifier [6 ] = nn .Linear (num_features , self .levels [- 1 ])
143
143
elif isinstance (self .criterion , HierarchicalSoftmaxLoss ):
144
- self .model .classifier [6 ] = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
144
+ self .model .module . classifier [6 ] = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
145
145
else :
146
- self .model .classifier [6 ] = nn .Linear (num_features , self .n_classes )
146
+ self .model .module . classifier [6 ] = nn .Linear (num_features , self .n_classes )
147
147
elif 'resnet' in self .model_name :
148
- num_features = self .model .fc .in_features
148
+ num_features = self .model .module . fc .in_features
149
149
if isinstance (self .criterion , LastLevelCELoss ):
150
- self .model .fc = nn .Linear (num_features , self .levels [- 1 ])
150
+ self .model .module . fc = nn .Linear (num_features , self .levels [- 1 ])
151
151
elif isinstance (self .criterion , HierarchicalSoftmaxLoss ):
152
- self .model .fc = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
152
+ self .model .module . fc = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
153
153
else :
154
- self .model .fc = nn .Linear (num_features , self .n_classes )
154
+ self .model .module . fc = nn .Linear (num_features , self .n_classes )
155
155
else :
156
156
if self .model_name in ['alexnet' , 'vgg' ]:
157
- num_features = self .model .module .classifier [6 ].in_features
157
+ num_features = self .model .module .module . classifier [6 ].in_features
158
158
if isinstance (self .criterion , LastLevelCELoss ):
159
- self .model .module .classifier [6 ] = nn .Linear (num_features , self .levels [- 1 ])
159
+ self .model .module .module . classifier [6 ] = nn .Linear (num_features , self .levels [- 1 ])
160
160
elif isinstance (self .criterion , HierarchicalSoftmaxLoss ):
161
- self .model .module .classifier [6 ] = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
161
+ self .model .module .module . classifier [6 ] = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
162
162
else :
163
- self .model .module .classifier [6 ] = nn .Linear (num_features , self .n_classes )
163
+ self .model .module .module . classifier [6 ] = nn .Linear (num_features , self .n_classes )
164
164
elif 'resnet' in self .model_name :
165
- num_features = self .model .module .fc .in_features
165
+ num_features = self .model .module .module . fc .in_features
166
166
if isinstance (self .criterion , LastLevelCELoss ):
167
- self .model .module .fc = nn .Linear (num_features , self .levels [- 1 ])
167
+ self .model .module .module . fc = nn .Linear (num_features , self .levels [- 1 ])
168
168
elif isinstance (self .criterion , HierarchicalSoftmaxLoss ):
169
169
self .model .module .fc = HierarchicalSoftmax (labelmap = self .labelmap , input_size = num_features )
170
170
else :
0 commit comments