Skip to content

Commit cbcd023

Browse files
authored
Add -normalize_gradients parameter
1 parent 97081c3 commit cbcd023

File tree

2 files changed

+36
-12
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ path or a full absolute path.
195195
when using ADAM you will probably need to play with other parameters to get good results, especially
196196
the style weight, content weight, and learning rate.
197197
* `-learning_rate`: Learning rate to use with the ADAM optimizer. Default is 1e1.
198+
* `-normalize_gradients`: If this flag is present, style and content gradients from each layer will be L1 normalized.
198199

199200
**Output options**:
200201
* `-output_image`: Name of the output image. Default is `out.png`.
@@ -313,4 +314,4 @@ If you find this code useful for your research, please cite:
313314
journal = {GitHub repository},
314315
howpublished = {\url{https://github.com/ProGamerGov/neural-style-pt}},
315316
}
316-
```
317+
```

neural_style.py

+34-11
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
parser.add_argument("-content_weight", type=float, default=5e0)
2222
parser.add_argument("-style_weight", type=float, default=1e2)
2323
parser.add_argument("-normalize_weights", action='store_true')
24+
parser.add_argument("-normalize_gradients", action='store_true')
2425
parser.add_argument("-tv_weight", type=float, default=1e-3)
2526
parser.add_argument("-num_iterations", type=int, default=1000)
2627
parser.add_argument("-init", choices=['random', 'image'], default='random')
@@ -121,13 +122,13 @@ def main():
121122

122123
if layerList['C'][c] in content_layers:
123124
print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
124-
loss_module = ContentLoss(params.content_weight)
125+
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
125126
net.add_module(str(len(net)), loss_module)
126127
content_losses.append(loss_module)
127128

128129
if layerList['C'][c] in style_layers:
129130
print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
130-
loss_module = StyleLoss(params.style_weight)
131+
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
131132
net.add_module(str(len(net)), loss_module)
132133
style_losses.append(loss_module)
133134
c+=1
@@ -137,14 +138,14 @@ def main():
137138

138139
if layerList['R'][r] in content_layers:
139140
print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
140-
loss_module = ContentLoss(params.content_weight)
141+
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
141142
net.add_module(str(len(net)), loss_module)
142143
content_losses.append(loss_module)
143144
next_content_idx += 1
144145

145146
if layerList['R'][r] in style_layers:
146147
print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
147-
loss_module = StyleLoss(params.style_weight)
148+
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
148149
net.add_module(str(len(net)), loss_module)
149150
style_losses.append(loss_module)
150151
next_style_idx += 1
@@ -339,15 +340,15 @@ def preprocess(image_name, image_size):
339340
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
340341
rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
341342
Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
342-
tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
343+
tensor = Normalize(rgb2bgr(Loader(image) * 255)).unsqueeze(0)
343344
return tensor
344345

345346

346347
# Undo the above preprocessing.
347348
def deprocess(output_tensor):
348349
Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
349350
bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
350-
output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
351+
output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 255
351352
output_tensor.clamp_(0, 1)
352353
Image2PIL = transforms.ToPILImage()
353354
image = Image2PIL(output_tensor.cpu())
@@ -399,18 +400,36 @@ def normalize_weights(content_losses, style_losses):
399400
i.strength = i.strength / max(i.target.size())
400401

401402

403+
# Scale gradients in the backward pass
404+
class ScaleGradients(torch.autograd.Function):
405+
@staticmethod
406+
def forward(self, input_tensor, strength):
407+
self.strength = strength
408+
return input_tensor
409+
410+
@staticmethod
411+
def backward(self, grad_output):
412+
grad_input = grad_output.clone()
413+
grad_input = grad_input / (torch.norm(grad_input, keepdim=True) + 1e-8)
414+
return grad_input * self.strength * self.strength, None
415+
416+
402417
# Define an nn Module to compute content loss
403418
class ContentLoss(nn.Module):
404419

405-
def __init__(self, strength):
420+
def __init__(self, strength, normalize):
406421
super(ContentLoss, self).__init__()
407422
self.strength = strength
408423
self.crit = nn.MSELoss()
409424
self.mode = 'None'
425+
self.normalize = normalize
410426

411427
def forward(self, input):
412428
if self.mode == 'loss':
413-
self.loss = self.crit(input, self.target) * self.strength
429+
loss = self.crit(input, self.target)
430+
if self.normalize:
431+
loss = ScaleGradients.apply(loss, self.strength)
432+
self.loss = loss * self.strength
414433
elif self.mode == 'capture':
415434
self.target = input.detach()
416435
return input
@@ -427,14 +446,15 @@ def forward(self, input):
427446
# Define an nn Module to compute style loss
428447
class StyleLoss(nn.Module):
429448

430-
def __init__(self, strength):
449+
def __init__(self, strength, normalize):
431450
super(StyleLoss, self).__init__()
432451
self.target = torch.Tensor()
433452
self.strength = strength
434453
self.gram = GramMatrix()
435454
self.crit = nn.MSELoss()
436455
self.mode = 'None'
437456
self.blend_weight = None
457+
self.normalize = normalize
438458

439459
def forward(self, input):
440460
self.G = self.gram(input)
@@ -447,7 +467,10 @@ def forward(self, input):
447467
else:
448468
self.target = self.target.add(self.blend_weight, self.G.detach())
449469
elif self.mode == 'loss':
450-
self.loss = self.strength * self.crit(self.G, self.target)
470+
loss = self.crit(self.G, self.target)
471+
if self.normalize:
472+
loss = ScaleGradients.apply(loss, self.strength)
473+
self.loss = self.strength * loss
451474
return input
452475

453476

@@ -465,4 +488,4 @@ def forward(self, input):
465488

466489

467490
if __name__ == "__main__":
468-
main()
491+
main()

0 commit comments

Comments (0)