SHARE
TWEET

Untitled

a guest Jun 15th, 2019 58 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  def _weighted_layer_loss(sess, layers, image, layer_specs, loss_fn):
      """Build a weighted sum of per-layer loss terms.

      For each ``(layer_name, weight)`` pair in *layer_specs*, the layer's
      activation is evaluated once in *sess* with ``layers['input']`` fed
      with *image*. That value becomes the fixed target handed to
      ``loss_fn(layers, target, layer_name)``.

      Returns:
          ``0.0`` when *layer_specs* is empty, otherwise the accumulated
          ``weight * loss_fn(...)`` expression.
      """
      total = 0.0
      for layer_name, weight in layer_specs:
          target = sess.run(layers[layer_name],
                            feed_dict={layers['input']: image})
          total = total + weight * loss_fn(layers, target, layer_name)
      return total


  def run_style_transfer(content_image_filename,
                         style_image_filename,
                         initialImage=None,
                         alpha=10, beta=1e-1,
                         epochs=1200,
                         learning_rate=5.0,
                         prefix=None,
                         save_every=100):
      """Optimize an image to match one image's content and another's style.

      Content and style targets are evaluated once from the graph built by
      ``get_nst_model(getModelWeightsAsDict())``. The value held in
      ``layers['input']`` is then optimized with Adam against
      ``getTotalLoss``.

      Args:
          content_image_filename: Path passed to ``load_image`` for the
              content image.
          style_image_filename: Path passed to ``load_image`` for the style
              image.
          initialImage: Starting value assigned to ``layers['input']``. When
              ``None``, ``generate_noise_image()`` is used instead.
          alpha: Forwarded to ``getTotalLoss`` (presumably the content-loss
              weight -- confirm against getTotalLoss).
          beta: Forwarded to ``getTotalLoss`` (presumably the style-loss
              weight -- confirm against getTotalLoss).
          epochs: Number of Adam update steps to run.
          learning_rate: Adam learning rate.
          prefix: Output filename prefix. Snapshots are written as
              ``'<prefix>_<step>.jpg'``. Required: when ``None`` a message is
              printed and the function returns without doing anything else.
          save_every: Save, log and display a snapshot every this many steps
              (steps are counted from 1). ``None`` or a value <= 0 disables
              periodic snapshots. The final step is always snapshotted so
              the result is never lost.

      Returns:
          None.
      """
      if prefix is None:
          print('Output file prefix not defined...')
          return

      content_image = load_image(content_image_filename)
      style_image = load_image(style_image_filename)

      weights_dict = getModelWeightsAsDict()
      tf.reset_default_graph()
      layers = get_nst_model(weights_dict)

      print('Model graph generated...\n')

      # A single throw-away session evaluates both sets of fixed targets. The
      # loss expressions built from them live in the default graph and are
      # reused by the optimization session below.
      with tf.Session() as sess:
          J_content = _weighted_layer_loss(sess, layers, content_image,
                                           CONTENT_LAYERS, getContentLoss)
          print('Content loss defined...\n')
          J_style = _weighted_layer_loss(sess, layers, style_image,
                                         STYLE_LAYERS, getStyleLoss)
      print('Style loss defined...\n')

      J_total = getTotalLoss(J_content, J_style, alpha=alpha, beta=beta)

      # minimize() returns the update operation, not the optimizer object.
      train_op = tf.train.AdamOptimizer(learning_rate).minimize(J_total)

      print('Losses defined...\n')

      # layers['input'] is updated via .assign(), so it must be a tf.Variable.
      # NOTE(review): presumably it is the only trainable variable -- confirm
      # with this printout.
      print('Trainable variables:\n', tf.trainable_variables())

      periodic = save_every is not None and save_every > 0

      with tf.Session() as sess:
          # Initialize after minimize() so Adam's own variables are included.
          sess.run(tf.global_variables_initializer())
          if initialImage is not None:
              start_image = initialImage
          else:
              start_image = generate_noise_image()
          sess.run(layers['input'].assign(start_image))

          # Steps are 1-based so the log and the filename agree.
          for step in range(1, epochs + 1):
              epoch_loss, epoch_content_loss, epoch_style_loss, _ = sess.run(
                  [J_total, J_content, J_style, train_op])

              is_checkpoint = step == epochs or (
                  periodic and step % save_every == 0)
              if not is_checkpoint:
                  continue

              generated_image = sess.run(layers['input'])
              generated_image = save_image(generated_image,
                                           prefix + '_' + str(step) + '.jpg')
              print('Loss after epoch %d: \nT: %f, \nC: %f, \nS: %f'
                    % (step, epoch_loss, epoch_content_loss, epoch_style_loss))
              # Converted RGB -> BGR for OpenCV display. Assumes save_image
              # returns an RGB array -- confirm.
              generated_image = cv2.cvtColor(generated_image,
                                             cv2.COLOR_RGB2BGR)
              cv2_imshow(generated_image)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top