-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss.py
67 lines (55 loc) · 2.24 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
class CE:
def __init__(self, model,adjust):
self.model = model
self.ce = nn.CrossEntropyLoss()
self.ce_pretrain = nn.CrossEntropyLoss(ignore_index=0)
self.adjust = adjust
def compute(self, batch):
seqs, labels = batch
if self.adjust:
seqs = seqs.reshape(labels.size()[0], -1)
out1 = self.model[0](seqs)
out1 = out1.reshape(labels.size()[0], 128,-1)
outputs = self.model[1](out1,hierarchical=False)
# outputs = self.model[2](outputs)
else :
outputs = self.model(seqs,hierarchical=False) # B * N
labels = labels.view(-1).long()
loss = self.ce(outputs, labels)
return loss
class Align:
def __init__(self):
self.mse = nn.MSELoss(reduction='mean')
self.ce = nn.CrossEntropyLoss()
def compute(self, rep_mask, rep_mask_prediction):
align_loss = self.mse(rep_mask, rep_mask_prediction)
return align_loss
class Reconstruct:
def __init__(self):
self.ce = nn.CrossEntropyLoss()
def compute(self, token_prediction_prob, tokens):
hits = torch.sum(torch.argmax(token_prediction_prob, dim=-1) == tokens)
NDCG10 = recalls_and_ndcgs_for_ks(token_prediction_prob.view(-1, token_prediction_prob.shape[-1]),
tokens.reshape(-1, 1), 10)
reconstruct_loss = self.ce(token_prediction_prob.view(-1, token_prediction_prob.shape[-1]), tokens.view(-1))
return reconstruct_loss, hits, NDCG10
def recalls_and_ndcgs_for_ks(scores, answers, k):
answers = answers.tolist()
labels = torch.zeros_like(scores).to(scores.device)
for i in range(len(answers)):
labels[i][answers[i]] = 1
answer_count = labels.sum(1)
labels_float = labels.float()
rank = (-scores).argsort(dim=1)
cut = rank
cut = cut[:, :k]
hits = labels_float.gather(1, cut)
position = torch.arange(2, 2 + k)
weights = 1 / torch.log2(position.float())
dcg = (hits * weights.to(hits.device)).sum(1)
idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in answer_count]).to(dcg.device)
ndcg = (dcg / idcg).mean()
ndcg = ndcg.cpu().item()
return ndcg