# understand_cnn.py
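# Reconstructs an image whose activations at the chosen VGG layers match those of a
# reference content image: the pixels themselves are optimized against a content loss
# plus a total-variation smoothness term, visualizing what the selected layer encodes.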
import os
import shutil
from datetime import datetime
from scipy import misc
import tensorflow as tf
import numpy as np
import vgg
tf.app.flags.DEFINE_integer("CONTENT_WEIGHT", 5e1, "Weight for content features loss")
tf.app.flags.DEFINE_integer("TV_WEIGHT", 1e-5, "Weight for total variation loss")
tf.app.flags.DEFINE_string("VGG_PATH", "/home/zhangxuesen/neuralstyle/imagenet-vgg-verydeep-19.mat",
"Path to vgg model weights")
tf.app.flags.DEFINE_float("LEARNING_RATE", 10., "Learning rate")
tf.app.flags.DEFINE_string("CONTENT_IMAGE", "cat.jpg", "Content image to use")
tf.app.flags.DEFINE_string("CONTENT_LAYERS", "hybrid_pool4","Which VGG layer to extract")
tf.app.flags.DEFINE_boolean("RANDOM_INIT", True, "Start from random noise")
tf.app.flags.DEFINE_integer("NUM_ITERATIONS", 5000, "Number of iterations")
FLAGS = tf.app.flags.FLAGS
def total_variation_loss(layer):
    """Penalize differences between neighbouring values to keep the reconstruction smooth."""
    shape = tf.shape(layer)
    height = shape[1]
    width = shape[2]
    # Differences between vertically (y) and horizontally (x) adjacent values.
    y = (tf.slice(layer, [0, 0, 0, 0], tf.pack([-1, height - 1, -1, -1]))
         - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1]))
    x = (tf.slice(layer, [0, 0, 0, 0], tf.pack([-1, -1, width - 1, -1]))
         - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1]))
    return tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + tf.nn.l2_loss(y) / tf.to_float(tf.size(y))
def get_content_features(content_path, content_layers):
    """Run the content image through VGG once and return the requested layer activations plus the image."""
    with tf.Graph().as_default():
        image = tf.expand_dims(tf.convert_to_tensor(misc.imread(content_path), tf.float32), 0)
        net, _ = vgg.net(FLAGS.VGG_PATH, image)
        layers = [net[layer] for layer in content_layers]
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.4
        with tf.Session(config=config) as sess:
            return sess.run(layers + [image])
def main(argv=None):
    content_path = FLAGS.CONTENT_IMAGE
    content_layers = FLAGS.CONTENT_LAYERS.split(',')
    # Precompute the target activations of the content image.
    feat_and_image = get_content_features(content_path, content_layers)
    content_features = feat_and_image[:-1]
    image_t = feat_and_image[-1]
    image = tf.constant(image_t)
    random = tf.random_normal(image_t.shape)
    # The variable being optimized is the image itself, initialized from noise or from the content image.
    initial = tf.Variable(random if FLAGS.RANDOM_INIT else image)
    net, _ = vgg.net(FLAGS.VGG_PATH, initial)
    # Content loss: squared distance between generated and target activations, normalized per layer.
    content_loss = 0
    for layer, feat in zip(content_layers, content_features):
        layer_size = tf.size(feat)
        content_loss += tf.nn.l2_loss(net[layer] - feat) / tf.to_float(layer_size)
    content_loss = FLAGS.CONTENT_WEIGHT * content_loss / len(content_layers)
    tv_loss = FLAGS.TV_WEIGHT * total_variation_loss(initial)
    total_loss = content_loss + tv_loss
    tf.scalar_summary('total_loss', total_loss)
    tf.image_summary('genpic', initial)
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)
    # Decay the learning rate by 5x every 3000 steps and feed the decayed rate to the optimizer.
    lr = tf.train.exponential_decay(FLAGS.LEARNING_RATE,
                                    global_step,
                                    3000,
                                    0.2,
                                    staircase=True)
    tf.scalar_summary('lr', lr)
    opt = tf.train.MomentumOptimizer(lr, 0.9)
    train_op = opt.minimize(total_loss, global_step=global_step)
    output_image = tf.saturate_cast(tf.squeeze(initial), tf.uint8)
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    with tf.Session(config=config) as sess:
        merged = tf.merge_all_summaries()
        tmp_dir = 'tmp/log_' + FLAGS.CONTENT_LAYERS
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)
        writer = tf.train.SummaryWriter(tmp_dir, sess.graph)
        sess.run(tf.initialize_all_variables())
        if not os.path.exists('genimg'):
            os.makedirs('genimg')
        for step in xrange(FLAGS.NUM_ITERATIONS):
            summary, _, loss_t = sess.run([merged, train_op, total_loss])
            writer.add_summary(summary, step)
            if step % 10 == 0:
                print('{} step {} with loss {}'.format(datetime.now(), step, loss_t))
            # Periodically dump the current reconstruction to disk.
            if step > 1500 and step % 500 == 0:
                image_t = sess.run(output_image)
                misc.imsave('genimg/' + FLAGS.CONTENT_LAYERS + '_' + str(step) + '.jpg',
                            np.squeeze(image_t))
if __name__ == '__main__':
    tf.app.run()