Skip to content

Commit 09771ec

Browse files
committed
replace model with model.module
1 parent dae8499 commit 09771ec

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

network/finetuner.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -137,34 +137,34 @@ def prepare_model(self, loading=False):
137137
# modify last layers based on the model being used
138138
if not loading:
139139
if self.model_name in ['alexnet', 'vgg']:
140-
num_features = self.model.classifier[6].in_features
140+
num_features = self.model.module.classifier[6].in_features
141141
if isinstance(self.criterion, LastLevelCELoss):
142-
self.model.classifier[6] = nn.Linear(num_features, self.levels[-1])
142+
self.model.module.classifier[6] = nn.Linear(num_features, self.levels[-1])
143143
elif isinstance(self.criterion, HierarchicalSoftmaxLoss):
144-
self.model.classifier[6] = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
144+
self.model.module.classifier[6] = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
145145
else:
146-
self.model.classifier[6] = nn.Linear(num_features, self.n_classes)
146+
self.model.module.classifier[6] = nn.Linear(num_features, self.n_classes)
147147
elif 'resnet' in self.model_name:
148-
num_features = self.model.fc.in_features
148+
num_features = self.model.module.fc.in_features
149149
if isinstance(self.criterion, LastLevelCELoss):
150-
self.model.fc = nn.Linear(num_features, self.levels[-1])
150+
self.model.module.fc = nn.Linear(num_features, self.levels[-1])
151151
elif isinstance(self.criterion, HierarchicalSoftmaxLoss):
152-
self.model.fc = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
152+
self.model.module.fc = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
153153
else:
154-
self.model.fc = nn.Linear(num_features, self.n_classes)
154+
self.model.module.fc = nn.Linear(num_features, self.n_classes)
155155
else:
156156
if self.model_name in ['alexnet', 'vgg']:
157-
num_features = self.model.module.classifier[6].in_features
157+
num_features = self.model.module.module.classifier[6].in_features
158158
if isinstance(self.criterion, LastLevelCELoss):
159-
self.model.module.classifier[6] = nn.Linear(num_features, self.levels[-1])
159+
self.model.module.module.classifier[6] = nn.Linear(num_features, self.levels[-1])
160160
elif isinstance(self.criterion, HierarchicalSoftmaxLoss):
161-
self.model.module.classifier[6] = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
161+
self.model.module.module.classifier[6] = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
162162
else:
163-
self.model.module.classifier[6] = nn.Linear(num_features, self.n_classes)
163+
self.model.module.module.classifier[6] = nn.Linear(num_features, self.n_classes)
164164
elif 'resnet' in self.model_name:
165-
num_features = self.model.module.fc.in_features
165+
num_features = self.model.module.module.fc.in_features
166166
if isinstance(self.criterion, LastLevelCELoss):
167-
self.model.module.fc = nn.Linear(num_features, self.levels[-1])
167+
self.model.module.module.fc = nn.Linear(num_features, self.levels[-1])
168168
elif isinstance(self.criterion, HierarchicalSoftmaxLoss):
169169
self.model.module.fc = HierarchicalSoftmax(labelmap=self.labelmap, input_size=num_features)
170170
else:

0 commit comments

Comments
 (0)