21
21
parser .add_argument ("-content_weight" , type = float , default = 5e0 )
22
22
parser .add_argument ("-style_weight" , type = float , default = 1e2 )
23
23
parser .add_argument ("-normalize_weights" , action = 'store_true' )
24
+ parser .add_argument ("-normalize_gradients" , action = 'store_true' )
24
25
parser .add_argument ("-tv_weight" , type = float , default = 1e-3 )
25
26
parser .add_argument ("-num_iterations" , type = int , default = 1000 )
26
27
parser .add_argument ("-init" , choices = ['random' , 'image' ], default = 'random' )
@@ -121,13 +122,13 @@ def main():
121
122
122
123
if layerList ['C' ][c ] in content_layers :
123
124
print ("Setting up content layer " + str (i ) + ": " + str (layerList ['C' ][c ]))
124
- loss_module = ContentLoss (params .content_weight )
125
+ loss_module = ContentLoss (params .content_weight , params . normalize_gradients )
125
126
net .add_module (str (len (net )), loss_module )
126
127
content_losses .append (loss_module )
127
128
128
129
if layerList ['C' ][c ] in style_layers :
129
130
print ("Setting up style layer " + str (i ) + ": " + str (layerList ['C' ][c ]))
130
- loss_module = StyleLoss (params .style_weight )
131
+ loss_module = StyleLoss (params .style_weight , params . normalize_gradients )
131
132
net .add_module (str (len (net )), loss_module )
132
133
style_losses .append (loss_module )
133
134
c += 1
@@ -137,14 +138,14 @@ def main():
137
138
138
139
if layerList ['R' ][r ] in content_layers :
139
140
print ("Setting up content layer " + str (i ) + ": " + str (layerList ['R' ][r ]))
140
- loss_module = ContentLoss (params .content_weight )
141
+ loss_module = ContentLoss (params .content_weight , params . normalize_gradients )
141
142
net .add_module (str (len (net )), loss_module )
142
143
content_losses .append (loss_module )
143
144
next_content_idx += 1
144
145
145
146
if layerList ['R' ][r ] in style_layers :
146
147
print ("Setting up style layer " + str (i ) + ": " + str (layerList ['R' ][r ]))
147
- loss_module = StyleLoss (params .style_weight )
148
+ loss_module = StyleLoss (params .style_weight , params . normalize_gradients )
148
149
net .add_module (str (len (net )), loss_module )
149
150
style_losses .append (loss_module )
150
151
next_style_idx += 1
@@ -339,15 +340,15 @@ def preprocess(image_name, image_size):
339
340
Loader = transforms .Compose ([transforms .Resize (image_size ), transforms .ToTensor ()])
340
341
rgb2bgr = transforms .Compose ([transforms .Lambda (lambda x : x [torch .LongTensor ([2 ,1 ,0 ])])])
341
342
Normalize = transforms .Compose ([transforms .Normalize (mean = [103.939 , 116.779 , 123.68 ], std = [1 ,1 ,1 ])])
342
- tensor = Normalize (rgb2bgr (Loader (image ) * 256 )).unsqueeze (0 )
343
+ tensor = Normalize (rgb2bgr (Loader (image ) * 255 )).unsqueeze (0 )
343
344
return tensor
344
345
345
346
346
347
# Undo the above preprocessing.
347
348
def deprocess (output_tensor ):
348
349
Normalize = transforms .Compose ([transforms .Normalize (mean = [- 103.939 , - 116.779 , - 123.68 ], std = [1 ,1 ,1 ])])
349
350
bgr2rgb = transforms .Compose ([transforms .Lambda (lambda x : x [torch .LongTensor ([2 ,1 ,0 ])])])
350
- output_tensor = bgr2rgb (Normalize (output_tensor .squeeze (0 ).cpu ())) / 256
351
+ output_tensor = bgr2rgb (Normalize (output_tensor .squeeze (0 ).cpu ())) / 255
351
352
output_tensor .clamp_ (0 , 1 )
352
353
Image2PIL = transforms .ToPILImage ()
353
354
image = Image2PIL (output_tensor .cpu ())
@@ -399,18 +400,36 @@ def normalize_weights(content_losses, style_losses):
399
400
i .strength = i .strength / max (i .target .size ())
400
401
401
402
403
+ # Scale gradients in the backward pass
404
class ScaleGradients(torch.autograd.Function):
    """Identity in the forward pass; rescales the gradient in the backward pass.

    During backpropagation the incoming gradient is divided by its own norm
    (plus a small epsilon) and then multiplied by ``strength`` twice, i.e. by
    ``strength ** 2``.
    """

    @staticmethod
    def forward(ctx, input_tensor, strength):
        """Return ``input_tensor`` unchanged and remember ``strength``.

        ``ctx`` is the autograd context object supplied by
        ``ScaleGradients.apply``; it is not an instance of this class.
        """
        # Stash the strength on the context so ``backward`` can read it.
        ctx.strength = strength
        return input_tensor

    @staticmethod
    def backward(ctx, grad_output):
        """Return the normalized, strength-scaled gradient for ``input_tensor``.

        The second returned value is ``None`` because ``strength`` does not
        receive a gradient.
        """
        # The division allocates a new tensor, so ``grad_output`` itself is
        # never modified. The 1e-8 term avoids dividing by zero when the
        # incoming gradient is all zeros.
        grad_input = grad_output / (torch.norm(grad_output, keepdim=True) + 1e-8)
        return grad_input * ctx.strength * ctx.strength, None
415
+
416
+
402
417
# Define an nn Module to compute content loss
403
418
class ContentLoss(nn.Module):
    """Pass-through module that records a content loss against a captured target.

    The module always returns its input unchanged. What it does on the side
    depends on ``self.mode``:

    * ``'capture'`` -- store a detached copy of the input as ``self.target``.
    * ``'loss'``    -- compute the MSE between the input and ``self.target``,
      scale it by ``self.strength`` and store it in ``self.loss``.
    * anything else (initially ``'None'``) -- do nothing.
    """

    def __init__(self, strength, normalize=False):
        """Create the loss module.

        Args:
            strength: multiplier applied to the MSE loss.
            normalize: when true, the loss is routed through
                ``ScaleGradients`` before being multiplied by ``strength``.
                Defaults to ``False`` so that callers passing only
                ``strength`` keep working.
        """
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.normalize = normalize

    def forward(self, input):
        """Update ``self.loss`` or ``self.target`` per ``self.mode``; return ``input``."""
        if self.mode == 'loss':
            # ``self.target`` is only assigned in 'capture' mode, so a capture
            # pass must have run before the first 'loss' pass.
            loss = self.crit(input, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = loss * self.strength
        elif self.mode == 'capture':
            # Detach so the stored target does not take part in backprop.
            self.target = input.detach()
        return input
@@ -427,14 +446,15 @@ def forward(self, input):
427
446
# Define an nn Module to compute style loss
428
447
class StyleLoss (nn .Module ):
429
448
430
- def __init__ (self , strength ):
449
+ def __init__ (self , strength , normalize ):
431
450
super (StyleLoss , self ).__init__ ()
432
451
self .target = torch .Tensor ()
433
452
self .strength = strength
434
453
self .gram = GramMatrix ()
435
454
self .crit = nn .MSELoss ()
436
455
self .mode = 'None'
437
456
self .blend_weight = None
457
+ self .normalize = normalize
438
458
439
459
def forward (self , input ):
440
460
self .G = self .gram (input )
@@ -447,7 +467,10 @@ def forward(self, input):
447
467
else :
448
468
self .target = self .target .add (self .blend_weight , self .G .detach ())
449
469
elif self .mode == 'loss' :
450
- self .loss = self .strength * self .crit (self .G , self .target )
470
+ loss = self .crit (self .G , self .target )
471
+ if self .normalize :
472
+ loss = ScaleGradients .apply (loss , self .strength )
473
+ self .loss = self .strength * loss
451
474
return input
452
475
453
476
@@ -465,4 +488,4 @@ def forward(self, input):
465
488
466
489
467
490
if __name__ == "__main__" :
468
- main ()
491
+ main ()
0 commit comments