Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _weighted_layer_loss(layers, image, layer_specs, loss_fn):
    """Build a weighted sum of per-layer losses whose targets come from ``image``.

    For each ``(layer_name, weight)`` pair in ``layer_specs``, the activation of
    ``layers[layer_name]`` is evaluated once with ``image`` fed to
    ``layers['input']``. The resulting fixed array is passed as the target to
    ``loss_fn(layers, target, layer_name)``, and ``weight * loss`` is accumulated.

    Args:
        layers: dict of model tensors as returned by ``get_nst_model``; must
            contain ``'input'`` and every layer name in ``layer_specs``.
        image: value fed to ``layers['input']`` to compute the targets.
        layer_specs: iterable of ``(layer_name, weight)`` pairs.
        loss_fn: callable ``(layers, target, layer_name) -> loss``.

    Returns:
        The accumulated weighted loss (``0.0`` if ``layer_specs`` is empty).
    """
    total = 0.0
    with tf.Session() as sess:
        for layer_name, weight in layer_specs:
            target = sess.run(layers[layer_name],
                              feed_dict={layers['input']: image})
            total = total + weight * loss_fn(layers, target, layer_name)
    return total


def run_style_transfer(content_image_filename,
                       style_image_filename,
                       initialImage=None,
                       alpha=10, beta=1e-1,
                       epochs=1200,
                       learning_rate=5.0,
                       prefix=None,
                       save_every=100):
    """Optimize ``layers['input']`` so it matches one image's content and another's style.

    Args:
        content_image_filename: path passed to ``load_image`` for the content image.
        style_image_filename: path passed to ``load_image`` for the style image.
        initialImage: starting value assigned to ``layers['input']``; when
            ``None``, ``generate_noise_image()`` is used instead.
        alpha, beta: forwarded to ``getTotalLoss`` to combine the content and
            style losses.
        epochs: number of Adam optimization steps.
        learning_rate: Adam learning rate.
        prefix: output filename prefix; checkpoints are written to
            ``<prefix>_<step>.jpg``. If ``None``, a message is printed and the
            function returns without doing anything.
        save_every: checkpoint interval in steps (default 100). The final step
            is always checkpointed as well, so trailing iterations are not lost.

    Returns:
        None.

    Raises:
        ValueError: if ``save_every`` is smaller than 1.
    """
    if prefix is None:
        print('Output file prefix not defined...')
        return
    if save_every < 1:
        raise ValueError('save_every must be a positive integer')

    content_image = load_image(content_image_filename)
    style_image = load_image(style_image_filename)
    weights_dict = getModelWeightsAsDict()
    # Start from a clean graph so repeated calls do not accumulate ops.
    tf.reset_default_graph()
    layers = get_nst_model(weights_dict)
    print('Model graph generated...\n')

    # Targets are evaluated once as fixed arrays; the losses compare them to
    # the activations produced by the (trainable) input image.
    J_content = _weighted_layer_loss(layers, content_image,
                                     CONTENT_LAYERS, getContentLoss)
    print('Content loss defined...\n')
    J_style = _weighted_layer_loss(layers, style_image,
                                   STYLE_LAYERS, getStyleLoss)
    print('Style loss defined...\n')

    J_total = getTotalLoss(J_content, J_style, alpha=alpha, beta=beta)
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(J_total)
    print('Losses defined...\n')
    # Presumably layers['input'] should be the only trainable variable; this
    # print lets the user check that.
    print('Trainable variables:\n', tf.trainable_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # The initializer above must run first; then overwrite the input
        # variable with the chosen starting image.
        start_image = (initialImage if initialImage is not None
                       else generate_noise_image())
        sess.run(layers['input'].assign(start_image))

        for epoch in range(epochs):
            epoch_loss, epoch_content_loss, epoch_style_loss, _ = sess.run(
                [J_total, J_content, J_style, optimizer])
            step = epoch + 1
            # Checkpoint every `save_every` steps and always on the last step.
            if step % save_every != 0 and step != epochs:
                continue
            generated_image = sess.run(layers['input'])
            # NOTE(review): assumes save_image returns an RGB image usable by
            # cv2.cvtColor below — confirm against its definition.
            generated_image = save_image(generated_image,
                                         prefix + '_' + str(step) + '.jpg')
            # Report the same 1-based step number used in the filename.
            print('Loss after epoch %d: \nT: %f, \nC: %f, \nS: %f'
                  % (step, epoch_loss, epoch_content_loss, epoch_style_loss))
            # NOTE(review): the original indentation was lost; showing the image
            # at every checkpoint (rather than once after training) is an
            # assumption — confirm. cv2_imshow is presumably Colab's helper.
            generated_image = cv2.cvtColor(generated_image, cv2.COLOR_RGB2BGR)
            cv2_imshow(generated_image)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement