Skip to content

Commit 5ff7b57

Browse files
committed
address inconsistent ordering of fake/real logits going into losses for dual contrastive loss, #289
1 parent e47fafa commit 5ff7b57

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,15 @@ def slerp(val, low, high):
299299
def gen_hinge_loss(fake, real):
300300
return fake.mean()
301301

302-
def hinge_loss(real, fake):
302+
def hinge_loss(fake, real):
303303
return (F.relu(1 + real) + F.relu(1 - fake)).mean()
304304

305-
def dual_contrastive_loss(real_logits, fake_logits):
305+
def dual_contrastive_loss(fake_logits, real_logits):
306306
device = real_logits.device
307307
real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))
308308

309309
def loss_half(t1, t2):
310-
t1 = rearrange(t1, 'i -> i ()')
310+
t1 = rearrange(t1, 'i -> i 1')
311311
t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
312312
t = torch.cat((t1, t2), dim = -1)
313313
return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))
@@ -1043,7 +1043,7 @@ def train(self):
10431043
real_output_loss = real_output_loss - fake_output.mean()
10441044
fake_output_loss = fake_output_loss - real_output.mean()
10451045

1046-
divergence = D_loss_fn(real_output_loss, fake_output_loss)
1046+
divergence = D_loss_fn(fake_output_loss, real_output_loss)
10471047
disc_loss = divergence
10481048

10491049
if self.has_fq:

stylegan2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.8.10'
1+
__version__ = '1.8.11'

0 commit comments

Comments
 (0)