@@ -232,14 +232,14 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
232
232
else :
233
233
loss .backward (** kwargs )
234
234
235
def gradient_penalty(images, output, weight = 10, center = 0.):
    """Gradient penalty regularizer for the discriminator.

    Penalizes the squared deviation of each sample's input-gradient norm
    from ``center``:  weight * mean((||d output / d images||_2 - center)**2).

    Args:
        images: input batch the gradients are taken with respect to.
            Must have ``requires_grad=True`` for a non-trivial penalty.
        output: discriminator scores computed from ``images``.
        weight: scalar multiplier on the penalty term (default 10).
        center: target gradient norm. ``0.`` gives the zero-centered
            penalty (R1-style); ``1.`` recovers the classic WGAN-GP
            one-centered penalty.

    Returns:
        A scalar tensor holding the penalty.
    """
    batch_size = images.shape[0]
    # create_graph=True keeps the gradient computation differentiable so
    # the penalty itself can be backpropagated during D's update.
    gradients = torch_grad(outputs = output, inputs = images,
                           grad_outputs = torch.ones(output.size(), device = images.device),
                           create_graph = True, retain_graph = True, only_inputs = True)[0]

    # Flatten per-sample gradients so the L2 norm is taken over all
    # pixels/channels of each sample at once.
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim = 1) - center) ** 2).mean()
243
243
244
244
def calc_pl_lengths (styles , images ):
245
245
device = images .device
@@ -396,15 +396,23 @@ def __init__(self, D, image_size):
396
396
super ().__init__ ()
397
397
self .D = D
398
398
399
def forward(self, images, prob = 0., types = None, detach = False, return_aug_images = False, input_requires_grad = False):
    """Run the wrapped discriminator ``self.D`` on (optionally augmented) images.

    Args:
        images: input image batch.
        prob: probability of applying the augmentation pipeline
            (random horizontal flip followed by DiffAugment).
        types: DiffAugment policy list; defaults to no augmentations.
            (``None`` sentinel instead of a mutable ``[]`` default.)
        detach: if True, detach the (possibly augmented) images from the
            autograd graph before scoring — used when training D on
            generator output without backpropagating into G.
        return_aug_images: if True, also return the images actually fed
            to D (after augmentation/detach) alongside the logits, so the
            caller can e.g. compute a gradient penalty on them.
        input_requires_grad: if True, mark the images as requiring grad
            before scoring (needed for gradient penalties w.r.t. inputs).

    Returns:
        ``logits`` by default, or ``(images, logits)`` when
        ``return_aug_images`` is True.
    """
    if types is None:
        types = []

    if random() < prob:
        images = random_hflip(images, prob = 0.5)
        images = DiffAugment(images, types = types)

    if detach:
        images = images.detach()

    # requires_grad_ is applied after detach, so a gradient penalty can
    # still be taken w.r.t. the detached (augmented) inputs.
    if input_requires_grad:
        images.requires_grad_()

    logits = self.D(images)

    if not return_aug_images:
        return logits

    return images, logits
408
416
409
417
# stylegan2 classes
410
418
@@ -1030,10 +1038,13 @@ def train(self):
1030
1038
w_styles = styles_def_to_tensor (w_space )
1031
1039
1032
1040
generated_images = G (w_styles , noise )
1033
- fake_output , fake_q_loss = D_aug (generated_images .clone ().detach (), detach = True , ** aug_kwargs )
1041
+ generated_images , ( fake_output , fake_q_loss ) = D_aug (generated_images .clone ().detach (), return_aug_images = True , input_requires_grad = apply_gradient_penalty , detach = True , ** aug_kwargs )
1034
1042
1035
1043
image_batch = next (self .loader ).cuda (self .rank )
1036
- image_batch .requires_grad_ ()
1044
+
1045
+ if apply_gradient_penalty :
1046
+ image_batch .requires_grad_ ()
1047
+
1037
1048
real_output , real_q_loss = D_aug (image_batch , ** aug_kwargs )
1038
1049
1039
1050
real_output_loss = real_output
@@ -1053,7 +1064,7 @@ def train(self):
1053
1064
disc_loss = disc_loss + quantize_loss
1054
1065
1055
1066
if apply_gradient_penalty :
1056
- gp = gradient_penalty (image_batch , real_output )
1067
+ gp = gradient_penalty (image_batch , real_output ) + gradient_penalty ( generated_images , fake_output )
1057
1068
self .last_gp_loss = gp .clone ().detach ().item ()
1058
1069
self .track (self .last_gp_loss , 'GP' )
1059
1070
disc_loss = disc_loss + gp
@@ -1382,7 +1393,7 @@ def load(self, num = -1):
1382
1393
1383
1394
self .steps = name * self .save_every
1384
1395
1385
- load_data = torch .load (self .model_name (name ))
1396
+ load_data = torch .load (self .model_name (name ), weights_only = True )
1386
1397
1387
1398
if 'version' in load_data :
1388
1399
print (f"loading from version { load_data ['version' ]} " )
0 commit comments