Advertisement
Guest User

Untitled

a guest
Jun 15th, 2019
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.42 KB | None | 0 0
  # --- Neural style transfer driver (TensorFlow 1.x graph mode) ---
def _weighted_loss_sum(layers, image, layer_weights, loss_fn):
    """Build the weighted sum of per-layer losses against fixed target activations.

    The target activation of every listed layer is computed in a single
    forward pass by feeding *image* into ``layers['input']``. *loss_fn* then
    builds, for each layer, a loss tensor comparing that layer with its target.

    Args:
        layers: mapping from layer name to graph tensor, as produced by
            ``get_nst_model``; must contain the key ``'input'``.
        image: value fed into ``layers['input']`` to obtain the targets
            (presumably an array from ``load_image`` — shape must match the
            model input; TODO confirm).
        layer_weights: iterable of ``(layer_name, weight)`` pairs.
        loss_fn: loss builder called as ``loss_fn(layers, target, layer_name)``
            (e.g. ``getContentLoss`` or ``getStyleLoss``).

    Returns:
        A tensor equal to ``sum(weight * loss_fn(layers, target, name))``.
        When *layer_weights* is empty this is a constant 0.0 tensor rather
        than a Python float, so it can still be fetched with ``sess.run``.
    """
    layer_weights = list(layer_weights)
    if not layer_weights:
        return tf.constant(0.0)

    with tf.Session() as sess:
        # One forward pass fetches every target instead of one run per layer.
        targets = sess.run([layers[name] for name, _ in layer_weights],
                           feed_dict={layers['input']: image})

    total = 0.0
    for (name, weight), target in zip(layer_weights, targets):
        total = total + weight * loss_fn(layers, target, name)
    return total


def run_style_transfer(content_image_filename,
                       style_image_filename,
                       initialImage=None,
                       alpha=10, beta=1e-1,
                       epochs=1200,
                       learning_rate=5.0,
                       prefix=None,
                       save_every=100):
    """Optimise an image to combine one image's content with another's style.

    Builds the style-transfer graph, defines the content/style/total losses,
    then runs ``epochs`` Adam steps on ``J_total``. Every ``save_every``
    steps the current ``layers['input']`` value is saved to
    ``f"{prefix}_{step}.jpg"`` via ``save_image``, the losses are printed,
    and the image is shown with ``cv2_imshow`` after RGB->BGR conversion.

    Args:
        content_image_filename: path passed to ``load_image`` for the content image.
        style_image_filename: path passed to ``load_image`` for the style image.
        initialImage: optional starting value assigned to ``layers['input']``;
            when None, ``generate_noise_image()`` is used instead.
        alpha: content-loss weight forwarded to ``getTotalLoss``.
        beta: style-loss weight forwarded to ``getTotalLoss``.
        epochs: number of optimiser steps to run.
        learning_rate: Adam learning rate.
        prefix: output filename prefix. If None, a message is printed and the
            function returns None without doing any work.
        save_every: snapshot/report interval in steps (must be >= 1).

    Raises:
        ValueError: if ``save_every`` is less than 1.
    """
    if prefix is None:
        print('Output file prefix not defined...')
        return
    if save_every < 1:
        raise ValueError('save_every must be >= 1, got %r' % (save_every,))

    content_image = load_image(content_image_filename)
    style_image = load_image(style_image_filename)

    weights_dict = getModelWeightsAsDict()
    tf.reset_default_graph()
    layers = get_nst_model(weights_dict)

    print('Model graph generated...\n')

    # Loss tensors whose targets come from the content / style images.
    J_content = _weighted_loss_sum(layers, content_image, CONTENT_LAYERS, getContentLoss)
    print('Content loss defined...\n')

    J_style = _weighted_loss_sum(layers, style_image, STYLE_LAYERS, getStyleLoss)
    print('Style loss defined...\n')

    J_total = getTotalLoss(J_content, J_style, alpha=alpha, beta=beta)

    # minimize() updates every trainable variable; the printout below lets you
    # check that only the input image is trainable, not the model weights.
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(J_total)

    print('Losses defined...\n')

    print('Trainable variables:\n', tf.trainable_variables())

    with tf.Session() as sess:
        # Initialise variables (including Adam's slots), then the image.
        sess.run(tf.global_variables_initializer())
        if initialImage is not None:
            sess.run(layers['input'].assign(initialImage))
        else:
            sess.run(layers['input'].assign(generate_noise_image()))

        for epoch in range(epochs):
            epoch_loss, epoch_content_loss, epoch_style_loss, _ = sess.run(
                [J_total, J_content, J_style, optimizer])

            # 1-based step count, used for both the filename and the message
            # (the message previously printed the 0-based index).
            step = epoch + 1
            # NOTE(review): the original paste lost its indentation; the
            # save/print/display steps are assumed to be periodic.
            if step % save_every == 0:
                generated_image = sess.run(layers['input'])
                # Assumes save_image returns the displayable RGB image it wrote.
                generated_image = save_image(generated_image, prefix + '_' + str(step) + '.jpg')
                print('Loss after epoch %d: \nT: %f, \nC: %f, \nS: %f' % (step, epoch_loss, epoch_content_loss, epoch_style_loss))
                generated_image = cv2.cvtColor(generated_image, cv2.COLOR_RGB2BGR)
                cv2_imshow(generated_image)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement