@@ -78,13 +78,13 @@ def train_FMNIST(arguments):
78
78
# transforms.Normalize(0.5, 0.5)
79
79
])
80
80
81
- lmap = labelmap_FMNIST ()
81
+ labelmap = labelmap_FMNIST ()
82
82
batch_size = arguments .batch_size
83
83
n_workers = arguments .n_workers
84
84
85
85
if arguments .debug :
86
86
print ("== Running in DEBUG mode!" )
87
- trainset = FMNISTHierarchical (root = '../database' , labelmap = lmap , train = False ,
87
+ trainset = FMNISTHierarchical (root = '../database' , labelmap = labelmap , train = False ,
88
88
download = True , transform = data_transforms )
89
89
trainloader = torch .utils .data .DataLoader (torch .utils .data .Subset (trainset , list (range (100 ))), batch_size = batch_size ,
90
90
shuffle = True , num_workers = n_workers )
@@ -100,14 +100,14 @@ def train_FMNIST(arguments):
100
100
data_loaders = {'train' : trainloader , 'val' : valloader , 'test' : testloader }
101
101
102
102
else :
103
- trainset = FMNISTHierarchical (root = '../database' , labelmap = lmap , train = True ,
103
+ trainset = FMNISTHierarchical (root = '../database' , labelmap = labelmap , train = True ,
104
104
download = True , transform = data_transforms )
105
- testset = FMNISTHierarchical (root = '../database' , labelmap = lmap , train = False ,
105
+ testset = FMNISTHierarchical (root = '../database' , labelmap = labelmap , train = False ,
106
106
download = True , transform = data_transforms )
107
107
108
108
# split the dataset into 80:10:10
109
109
train_indices_from_train , val_indices_from_train , val_indices_from_test , test_indices_from_test = \
110
- FMNIST_set_indices (trainset , testset , lmap )
110
+ FMNIST_set_indices (trainset , testset , labelmap )
111
111
112
112
trainloader = torch .utils .data .DataLoader (torch .utils .data .Subset (trainset , train_indices_from_train ),
113
113
batch_size = batch_size ,
@@ -125,18 +125,25 @@ def train_FMNIST(arguments):
125
125
126
126
data_loaders = {'train' : trainloader , 'val' : valloader , 'test' : testloader }
127
127
128
- eval_type = MultiLabelEvaluation (os .path .join (arguments .experiment_dir , arguments .experiment_name ), lmap )
128
+ weight = None
129
+ if arguments .class_weights :
130
+ n_train = torch .zeros (labelmap .n_classes )
131
+ for data_item in data_loaders ['train' ]:
132
+ n_train += torch .sum (data_item ['labels' ], 0 )
133
+ weight = 1.0 / n_train
134
+
135
+ eval_type = MultiLabelEvaluation (os .path .join (arguments .experiment_dir , arguments .experiment_name ), labelmap )
129
136
if arguments .evaluator == 'MLST' :
130
- eval_type = MultiLabelEvaluationSingleThresh (os .path .join (arguments .experiment_dir , arguments .experiment_name ), lmap )
137
+ eval_type = MultiLabelEvaluationSingleThresh (os .path .join (arguments .experiment_dir , arguments .experiment_name ), labelmap )
131
138
132
139
use_criterion = None
133
140
if arguments .loss == 'multi_label' :
134
- use_criterion = MultiLabelSMLoss ()
141
+ use_criterion = MultiLabelSMLoss (weight = weight )
135
142
elif arguments .loss == 'multi_level' :
136
- use_criterion = MultiLevelCELoss (labelmap = lmap )
137
- eval_type = MultiLevelEvaluation (os .path .join (arguments .experiment_dir , arguments .experiment_name ), lmap )
143
+ use_criterion = MultiLevelCELoss (labelmap = labelmap , weight = weight )
144
+ eval_type = MultiLevelEvaluation (os .path .join (arguments .experiment_dir , arguments .experiment_name ), labelmap )
138
145
139
- FMNIST_trainer = FMNIST (data_loaders = data_loaders , labelmap = lmap ,
146
+ FMNIST_trainer = FMNIST (data_loaders = data_loaders , labelmap = labelmap ,
140
147
criterion = use_criterion ,
141
148
lr = arguments .lr ,
142
149
batch_size = batch_size , evaluator = eval_type ,
@@ -257,6 +264,7 @@ def FMNIST_set_indices(trainset, testset, labelmap=labelmap_FMNIST()):
257
264
parser .add_argument ("--resume" , help = 'Continue training from last checkpoint.' , action = 'store_true' )
258
265
parser .add_argument ("--model" , help = 'NN model to use.' , type = str , required = True )
259
266
parser .add_argument ("--freeze_weights" , help = 'This flag fine tunes only the last layer.' , action = 'store_true' )
267
+ parser .add_argument ("--class_weights" , help = 'Re-weigh the loss function based on inverse class freq.' , action = 'store_true' )
260
268
parser .add_argument ("--set_mode" , help = 'If use training or testing mode (loads best model).' , type = str , required = True )
261
269
parser .add_argument ("--loss" , help = 'Loss function to use.' , type = str , required = True )
262
270
args = parser .parse_args ()
0 commit comments