@@ -299,15 +299,15 @@ def slerp(val, low, high):
 def gen_hinge_loss(fake, real):
     return fake.mean()
 
-def hinge_loss(real, fake):
+def hinge_loss(fake, real):
     return (F.relu(1 + real) + F.relu(1 - fake)).mean()
 
-def dual_contrastive_loss(real_logits, fake_logits):
+def dual_contrastive_loss(fake_logits, real_logits):
     device = real_logits.device
     real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))
 
     def loss_half(t1, t2):
-        t1 = rearrange(t1, 'i -> i ()')
+        t1 = rearrange(t1, 'i -> i 1')
         t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
         t = torch.cat((t1, t2), dim = -1)
         return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))
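For context (not part of the commit): a self-contained sketch of what the inner loss_half computes with the corrected 'i -> i 1' pattern. Each entry of t1 becomes class 0 of a row whose remaining columns are all of t2, so the cross entropy pushes every t1 logit to dominate every t2 logit. The tensor sizes below are illustrative assumptions, and the original reads device from the enclosing closure, replaced here with t1.device.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

def loss_half(t1, t2):
    t1 = rearrange(t1, 'i -> i 1')                # (i,) -> (i, 1) column of "positives"
    t2 = repeat(t2, 'j -> i j', i = t1.shape[0])  # tile "negatives" across each row: (i, j)
    t = torch.cat((t1, t2), dim = -1)             # each row = [positive, negatives...]
    # target class 0 is the t1 entry, so each t1 logit should beat every t2 logit
    return F.cross_entropy(t, torch.zeros(t1.shape[0], device = t1.device, dtype = torch.long))

real_logits = torch.randn(8)  # assumed flattened discriminator outputs
fake_logits = torch.randn(8)
print(loss_half(real_logits, fake_logits).item())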
@@ -1043,7 +1043,7 @@ def train(self):
                 real_output_loss = real_output_loss - fake_output.mean()
                 fake_output_loss = fake_output_loss - real_output.mean()
 
-            divergence = D_loss_fn(real_output_loss, fake_output_loss)
+            divergence = D_loss_fn(fake_output_loss, real_output_loss)
             disc_loss = divergence
 
             if self.has_fq:
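Also for context (not part of the commit): because the call site swaps together with the signature, each tensor still binds to the same parameter and the computed divergence is unchanged; the commit only makes the argument order consistent with gen_hinge_loss. A minimal check of that equivalence, with hypothetical hinge_loss_old / hinge_loss_new standing in for the before/after definitions:

import torch
import torch.nn.functional as F

def hinge_loss_old(real, fake):   # signature before the commit
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

def hinge_loss_new(fake, real):   # signature after the commit
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

real_out, fake_out = torch.randn(4), torch.randn(4)
# old call order vs. new call order yields the identical loss
assert torch.equal(hinge_loss_old(real_out, fake_out),
                   hinge_loss_new(fake_out, real_out))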