Skip to content

Commit 6d6db3a

Browse files
committed
switch to sparse one hot label
1 parent fc748a1 commit 6d6db3a

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

.gitignore

+3
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
# Other temp files
2+
*.swp
3+
1 4
# Byte-compiled / optimized / DLL files
2 5
__pycache__/
3 6
*.py[cod]

cross_entropy.py

+4-9
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,10 @@ def forward(ctx, *args):
5353
if ctx.compute_loss:
5454
_loss_list = []
5555
for gpu_id, softmax in enumerate(softmax_list):
56-
if isinstance(ctx.label_split[gpu_id], torch.sparse.LongTensor):
57-
idx = ctx.label_split[gpu_id]._indices()
58-
# FIXME move _loss to gpu?
59-
_loss = torch.zeros(ctx.batch_size)
60-
_loss.scatter_(dim=0, index=idx[0], src=softmax[tuple(idx)])
61-
_loss_list.append(_loss)
62-
else:
63-
_loss = torch.sum(softmax * ctx.label_split[gpu_id], dim=1)
64-
_loss_list.append(_loss)
56+
idx = ctx.label_split[gpu_id]._indices()
57+
_loss = torch.zeros(ctx.batch_size).to(gpu_id)
58+
_loss.scatter_(dim=0, index=idx[0], src=softmax[tuple(idx)])
59+
_loss_list.append(_loss)
6560
_loss = comm.reduce_add(_loss_list, destination=0)
6661
log_loss = -torch.log(_loss)
6762
loss = torch.mean(log_loss)

train.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
from model import ft_net
1616
from cross_entropy import ModelParallelCrossEntropy
17-
from utils import get_class_split, get_onehot_label, compute_batch_acc
17+
from utils import get_class_split, get_sparse_onehot_label, compute_batch_acc
1818

1919

2020
def get_data_loader(data_path, batch_size):
@@ -45,7 +45,7 @@ def train_model(opt, data_loader, model, criterion, optimizer, class_split):
4545
images, labels = data_loader_iter.next()
4646
images = images.cuda(0)
4747
labels = labels.cuda(0)
48-
onehot_labels = get_onehot_label(labels, opt.num_gpus, opt.num_classes, opt.model_parallel, class_split)
48+
onehot_labels = get_sparse_onehot_label(labels, opt.num_gpus, opt.num_classes, opt.model_parallel, class_split)
4949
# Forward
5050
optimizer.zero_grad()
5151
logits = model(images, labels=onehot_labels)

0 commit comments

Comments
 (0)