Commit 107e9e5 (parent: 22e70c9)

implement re-weighted loss; add as cmd line arguments

4 files changed: +70 -26 lines

network/ethec_experiments.py (+16 -5)

@@ -121,8 +121,9 @@ def ETHEC_train_model(arguments):
     trainloader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               num_workers=n_workers,
-                                              shuffle=True)
-                                              # sampler=WeightedResampler(train_set))
+                                              shuffle=True if arguments.class_weights else False,
+                                              sampler=None if arguments.class_weights else WeightedResampler(
+                                                  train_set))

     valloader = torch.utils.data.DataLoader(val_set,
                                             batch_size=batch_size,
@@ -138,7 +139,9 @@ def ETHEC_train_model(arguments):
     trainloader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               num_workers=n_workers,
-                                              sampler=WeightedResampler(train_set))
+                                              shuffle=True if arguments.class_weights else False,
+                                              sampler=None if arguments.class_weights else WeightedResampler(
+                                                  train_set))
     valloader = torch.utils.data.DataLoader(val_set,
                                             batch_size=batch_size,
                                             shuffle=False, num_workers=n_workers)
@@ -148,16 +151,23 @@ def ETHEC_train_model(arguments):

     data_loaders = {'train': trainloader, 'val': valloader, 'test': testloader}

+    weight = None
+    if arguments.class_weights:
+        n_train = torch.zeros(labelmap.n_classes)
+        for data_item in data_loaders['train']:
+            n_train += torch.sum(data_item['labels'], 0)
+        weight = 1.0/n_train
+
     eval_type = MultiLabelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)
     if arguments.evaluator == 'MLST':
         eval_type = MultiLabelEvaluationSingleThresh(os.path.join(arguments.experiment_dir, arguments.experiment_name),
                                                      labelmap)

     use_criterion = None
     if arguments.loss == 'multi_label':
-        use_criterion = MultiLabelSMLoss()
+        use_criterion = MultiLabelSMLoss(weight=weight)
     elif arguments.loss == 'multi_level':
-        use_criterion = MultiLevelCELoss(labelmap=labelmap)
+        use_criterion = MultiLevelCELoss(labelmap=labelmap, weight=weight)
         eval_type = MultiLevelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)

     ETHEC_trainer = ETHECExperiment(data_loaders=data_loaders, labelmap=labelmap,
@@ -196,6 +206,7 @@ def ETHEC_train_model(arguments):
     parser.add_argument("--model", help='NN model to use. Use one of [`multi_label`, `multi_level`]',
                         type=str, required=True)
     parser.add_argument("--loss", help='Loss function to use.', type=str, required=True)
+    parser.add_argument("--class_weights", help='Re-weigh the loss function based on inverse class freq.', action='store_true')
     parser.add_argument("--freeze_weights", help='This flag fine tunes only the last layer.', action='store_true')
     parser.add_argument("--set_mode", help='If use training or testing mode (loads best model).', type=str,
                         required=True)
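Note on the DataLoader change above: PyTorch's DataLoader does not allow shuffle=True together with a custom sampler, so the commit toggles the two as a pair. With --class_weights the imbalance is handled by re-weighting the loss and the loader falls back to plain shuffling; without it, WeightedResampler handles imbalance at sampling time. The script is then invoked with the new flag, e.g. python network/ethec_experiments.py --loss multi_level --class_weights ... (remaining required arguments omitted). A minimal self-contained sketch of the pattern, where make_loader and sampler_factory are illustrative names, not part of this repository:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

def make_loader(dataset, use_class_weights, sampler_factory, batch_size=32):
    # shuffle and sampler are mutually exclusive in DataLoader, so exactly
    # one of the two imbalance strategies is active at a time
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=use_class_weights,
                      sampler=None if use_class_weights else sampler_factory(dataset))

# dummy dataset; RandomSampler stands in for the repository's WeightedResampler
data = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
loader = make_loader(data, use_class_weights=False, sampler_factory=RandomSampler)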

network/fashion_mnist.py (+19 -11)

@@ -78,13 +78,13 @@ def train_FMNIST(arguments):
         # transforms.Normalize(0.5, 0.5)
     ])

-    lmap = labelmap_FMNIST()
+    labelmap = labelmap_FMNIST()
     batch_size = arguments.batch_size
     n_workers = arguments.n_workers

     if arguments.debug:
         print("== Running in DEBUG mode!")
-        trainset = FMNISTHierarchical(root='../database', labelmap=lmap, train=False,
+        trainset = FMNISTHierarchical(root='../database', labelmap=labelmap, train=False,
                                       download=True, transform=data_transforms)
         trainloader = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, list(range(100))), batch_size=batch_size,
                                                   shuffle=True, num_workers=n_workers)
@@ -100,14 +100,14 @@ def train_FMNIST(arguments):
         data_loaders = {'train': trainloader, 'val': valloader, 'test': testloader}

     else:
-        trainset = FMNISTHierarchical(root='../database', labelmap=lmap, train=True,
+        trainset = FMNISTHierarchical(root='../database', labelmap=labelmap, train=True,
                                       download=True, transform=data_transforms)
-        testset = FMNISTHierarchical(root='../database', labelmap=lmap, train=False,
+        testset = FMNISTHierarchical(root='../database', labelmap=labelmap, train=False,
                                      download=True, transform=data_transforms)

         # split the dataset into 80:10:10
         train_indices_from_train, val_indices_from_train, val_indices_from_test, test_indices_from_test = \
-            FMNIST_set_indices(trainset, testset, lmap)
+            FMNIST_set_indices(trainset, testset, labelmap)

         trainloader = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, train_indices_from_train),
                                                   batch_size=batch_size,
@@ -125,18 +125,25 @@ def train_FMNIST(arguments):

     data_loaders = {'train': trainloader, 'val': valloader, 'test': testloader}

-    eval_type = MultiLabelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), lmap)
+    weight = None
+    if arguments.class_weights:
+        n_train = torch.zeros(labelmap.n_classes)
+        for data_item in data_loaders['train']:
+            n_train += torch.sum(data_item['labels'], 0)
+        weight = 1.0 / n_train
+
+    eval_type = MultiLabelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)
     if arguments.evaluator == 'MLST':
-        eval_type = MultiLabelEvaluationSingleThresh(os.path.join(arguments.experiment_dir, arguments.experiment_name), lmap)
+        eval_type = MultiLabelEvaluationSingleThresh(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)

     use_criterion = None
     if arguments.loss == 'multi_label':
-        use_criterion = MultiLabelSMLoss()
+        use_criterion = MultiLabelSMLoss(weight=weight)
     elif arguments.loss == 'multi_level':
-        use_criterion = MultiLevelCELoss(labelmap=lmap)
-        eval_type = MultiLevelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), lmap)
+        use_criterion = MultiLevelCELoss(labelmap=labelmap, weight=weight)
+        eval_type = MultiLevelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)

-    FMNIST_trainer = FMNIST(data_loaders=data_loaders, labelmap=lmap,
+    FMNIST_trainer = FMNIST(data_loaders=data_loaders, labelmap=labelmap,
                             criterion=use_criterion,
                             lr=arguments.lr,
                             batch_size=batch_size, evaluator=eval_type,
@@ -257,6 +264,7 @@ def FMNIST_set_indices(trainset, testset, labelmap=labelmap_FMNIST()):
     parser.add_argument("--resume", help='Continue training from last checkpoint.', action='store_true')
     parser.add_argument("--model", help='NN model to use.', type=str, required=True)
     parser.add_argument("--freeze_weights", help='This flag fine tunes only the last layer.', action='store_true')
+    parser.add_argument("--class_weights", help='Re-weigh the loss function based on inverse class freq.', action='store_true')
     parser.add_argument("--set_mode", help='If use training or testing mode (loads best model).', type=str, required=True)
     parser.add_argument("--loss", help='Loss function to use.', type=str, required=True)
     args = parser.parse_args()
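The weight computation added in each script makes one pass over the training loader, summing the multi-hot label vectors to get per-class counts and taking the reciprocal. A toy illustration of what it produces, assuming each batch carries labels of shape (batch, n_classes) as the loop implies:

import torch

n_classes = 4
# two dummy batches of multi-hot labels, shape (batch, n_classes)
batches = [torch.tensor([[1., 0., 1., 0.],
                         [1., 0., 0., 1.]]),
           torch.tensor([[1., 1., 0., 0.]])]

n_train = torch.zeros(n_classes)
for labels in batches:
    n_train += torch.sum(labels, 0)   # accumulate per-class counts

weight = 1.0 / n_train                # inverse class frequency
print(n_train)   # tensor([3., 1., 1., 1.])
print(weight)    # tensor([0.3333, 1.0000, 1.0000, 1.0000])

One observable caveat: a class that never occurs in the training split would get an inf weight here, since 1.0/0 is inf in torch; the commit does not special-case that.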

network/finetuner.py (+10 -2)

@@ -409,16 +409,23 @@ def train_cifar10(arguments):

     data_loaders = {'train': trainloader, 'val': valloader, 'test': testloader}

+    weight = None
+    if arguments.class_weights:
+        n_train = torch.zeros(labelmap.n_classes)
+        for data_item in data_loaders['train']:
+            n_train += torch.sum(data_item['labels'], 0)
+        weight = 1.0 / n_train
+
     eval_type = MultiLabelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)
     if arguments.evaluator == 'MLST':
         eval_type = MultiLabelEvaluationSingleThresh(os.path.join(arguments.experiment_dir, arguments.experiment_name),
                                                      labelmap)

     use_criterion = None
     if arguments.loss == 'multi_label':
-        use_criterion = MultiLabelSMLoss()
+        use_criterion = MultiLabelSMLoss(weight=weight)
     elif arguments.loss == 'multi_level':
-        use_criterion = MultiLevelCELoss(labelmap=labelmap)
+        use_criterion = MultiLevelCELoss(labelmap=labelmap, weight=weight)
         eval_type = MultiLevelEvaluation(os.path.join(arguments.experiment_dir, arguments.experiment_name), labelmap)

     cifar_trainer = CIFAR10(data_loaders=data_loaders, labelmap=labelmap,
@@ -586,6 +593,7 @@ def train_alexnet_binary():
     parser.add_argument("--resume", help='Continue training from last checkpoint.', action='store_true')
     parser.add_argument("--model", help='NN model to use.', type=str, required=True)
     parser.add_argument("--loss", help='Loss function to use.', type=str, required=True)
+    parser.add_argument("--class_weights", help='Re-weigh the loss function based on inverse class freq.', action='store_true')
     parser.add_argument("--freeze_weights", help='This flag fine tunes only the last layer.', action='store_true')
     parser.add_argument("--set_mode", help='If use training or testing mode (loads best model).', type=str,
                         required=True)
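In the multi_label branch the vector is handed straight to torch.nn.MultiLabelSoftMarginLoss, whose weight argument rescales each class's term of the loss, so frequent classes (small 1/count) contribute proportionally less. A minimal sketch of that inherited behaviour, with made-up numbers:

import torch

n_classes = 3
weight = torch.tensor([0.1, 1.0, 1.0])   # e.g. class 0 is ten times as frequent
criterion = torch.nn.MultiLabelSoftMarginLoss(weight=weight)

logits = torch.randn(4, n_classes)                     # raw scores, batch of 4
targets = torch.randint(0, 2, (4, n_classes)).float()  # multi-hot labels
print(criterion(logits, targets))                      # weighted mean loss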

network/loss.py (+25 -8)

@@ -4,12 +4,28 @@


 class MultiLevelCELoss(torch.nn.Module):
-    def __init__(self, labelmap, weights=None):
+    def __init__(self, labelmap, level_weights=None, weight=None):
         torch.nn.Module.__init__(self)
         self.labelmap = labelmap
-        self.weights = [1.0]*len(self.labelmap.levels) if weights is None else weights
-        self.criterion = nn.CrossEntropyLoss(reduction='none')
-        print('==Using the following weights config for multi level cross entropy loss: {}'.format(self.weights))
+        self.level_weights = [1.0] * len(self.labelmap.levels) if level_weights is None else level_weights
+
+        self.criterion = []
+        if weight is None:
+            for level_len in self.labelmap.levels:
+                self.criterion.append(nn.CrossEntropyLoss(weight=None, reduction='none'))
+        else:
+            level_stop, level_start = [], []
+            for level_id, level_len in enumerate(self.labelmap.levels):
+                if level_id == 0:
+                    level_start.append(0)
+                    level_stop.append(level_len)
+                else:
+                    level_start.append(level_stop[level_id - 1])
+                    level_stop.append(level_stop[level_id - 1] + level_len)
+                self.criterion.append(nn.CrossEntropyLoss(weight=weight[level_start[level_id]:level_stop[level_id]],
+                                                          reduction='none'))
+
+        print('==Using the following weights config for multi level cross entropy loss: {}'.format(self.level_weights))

     def forward(self, outputs, labels, level_labels):
         # print('Outputs: {}'.format(outputs))
@@ -18,23 +34,24 @@ def forward(self, outputs, labels, level_labels):
         loss = 0.0
         for level_id, level in enumerate(self.labelmap.levels):
             if level_id == 0:
-                loss += self.weights[level_id] * self.criterion(outputs[:, 0:level], level_labels[:, level_id])
+                loss += self.level_weights[level_id] * self.criterion[level_id](outputs[:, 0:level], level_labels[:, level_id])
                 # print(self.weights[level_id] * self.criterion(outputs[:, 0:level], level_labels[:, level_id]))
             else:
                 start = sum([self.labelmap.levels[l_id] for l_id in range(level_id)])
                 # print([self.labelmap.levels[l_id] for l_id in range(level_id)], level)
                 # print(outputs[:, start:start+level])
                 # print(self.weights[level_id] * self.criterion(outputs[:, start:start+level],
                 #                                               level_labels[:, level_id]))
-                loss += self.weights[level_id] * self.criterion(outputs[:, start:start+level],
-                                                                level_labels[:, level_id])
+                loss += self.level_weights[level_id] * self.criterion[level_id](outputs[:, start:start + level],
+                                                                                level_labels[:, level_id])
         # print('Loss per sample: {}'.format(loss))
         # print('Avg loss: {}'.format(torch.mean(loss)))
         return torch.mean(loss)


 class MultiLabelSMLoss(torch.nn.MultiLabelSoftMarginLoss):
     def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
+        print(weight)
         torch.nn.MultiLabelSoftMarginLoss.__init__(self, weight, size_average, reduce, reduction)

     def forward(self, outputs, labels, level_labels):
@@ -43,7 +60,7 @@ def forward(self, outputs, labels, level_labels):

 if __name__ == '__main__':
     lmap = ETHECLabelMap()
-    criterion = MultiLevelCELoss(labelmap=lmap, weights=[1, 1, 1, 1])
+    criterion = MultiLevelCELoss(labelmap=lmap, level_weights=[1, 1, 1, 1])
     output, level_labels = torch.zeros((1, lmap.n_classes)), torch.tensor([[0,
                                                                             7-lmap.levels[0],
                                                                             90-(lmap.levels[0]+lmap.levels[1]),
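The constructor change above carves the flat per-class weight vector into contiguous per-level slices and builds one nn.CrossEntropyLoss per level; forward then scores each level's slice of the logits with its own criterion and scales the result by level_weights[level_id]. A stand-alone sketch of the slicing, using a hypothetical two-level hierarchy (levels = [2, 5], i.e. classes 0-1 on level 0 and classes 2-6 on level 1 of a flat 7-class vector):

import torch
import torch.nn as nn

levels = [2, 5]                      # hypothetical label hierarchy
weight = torch.rand(sum(levels))     # one weight per class, levels flattened

criteria, starts, start = [], [], 0
for level_len in levels:
    # each level's CE loss sees only that level's slice of the weight vector
    criteria.append(nn.CrossEntropyLoss(weight=weight[start:start + level_len],
                                        reduction='none'))
    starts.append(start)
    start += level_len

# forward then evaluates outputs[:, 0:2] with criteria[0] and
# outputs[:, 2:7] with criteria[1], summing the per-sample losses:
outputs = torch.randn(8, sum(levels))
level_labels = torch.stack([torch.randint(0, 2, (8,)),
                            torch.randint(0, 5, (8,))], dim=1)
loss = sum(c(outputs[:, s:s + l], level_labels[:, i])
           for i, (c, s, l) in enumerate(zip(criteria, starts, levels)))
print(torch.mean(loss))

The running-offset bookkeeping here mirrors the level_start/level_stop lists in the diff; the slices are identical, only the variable names differ.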
