Advertisement
Guest User

Untitled

a guest
Mar 21st, 2019
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.08 KB | None | 0 0
  1. def next_batch(batch_size):
  2. label = [0, 1, 0, 0, 0]
  3. X = []
  4. Y = []
  5. for i in range(0, batch_size):
  6. rand = random.choice(os.listdir(mnist))
  7. rand = mnist + rand
  8. img = cv2.imread(str(rand), 0)
  9. img = np.array(img)
  10. img = img.ravel()
  11. X.append(img)
  12. Y.append(label)
  13. X = np.array(X)
  14. Y = np.array(Y)
  15. return X, Y
  16.  
  17. def train(train_model=True):
  18. """
  19. Used to train the autoencoder by passing in the necessary inputs.
  20. :param train_model: True -> Train the model, False -> Load the latest trained model and show the image grid.
  21. :return: does not return anything
  22. """
  23. with tf.variable_scope(tf.get_variable_scope()):
  24. encoder_output = encoder(x_input)
  25. # Concat class label and the encoder output
  26. decoder_input = tf.concat([y_input, encoder_output], 1)
  27. decoder_output = decoder(decoder_input)
  28.  
  29. with tf.variable_scope(tf.get_variable_scope()):
  30. d_real = discriminator(real_distribution)
  31. d_fake = discriminator(encoder_output, reuse=True)
  32.  
  33. with tf.variable_scope(tf.get_variable_scope()):
  34. decoder_image = decoder(manual_decoder_input, reuse=True)
  35.  
  36. # Autoencoder loss
  37. autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))
  38.  
  39. # Discriminator Loss
  40. dc_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real))
  41. dc_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake))
  42. dc_loss = dc_loss_fake + dc_loss_real
  43.  
  44. # Generator loss
  45. generator_loss = tf.reduce_mean(
  46. tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake))
  47.  
  48. all_variables = tf.trainable_variables()
  49. dc_var = [var for var in all_variables if 'dc_' in var.name]
  50. en_var = [var for var in all_variables if 'e_' in var.name]
  51.  
  52. # Optimizers
  53. autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
  54. beta1=beta1).minimize(autoencoder_loss)
  55. discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
  56. beta1=beta1).minimize(dc_loss, var_list=dc_var)
  57. generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
  58. beta1=beta1).minimize(generator_loss, var_list=en_var)
  59.  
  60. init = tf.global_variables_initializer()
  61.  
  62. # Reshape images to display them
  63. input_images = tf.reshape(x_input, [-1, 368, 432, 1])
  64. generated_images = tf.reshape(decoder_output, [-1, 368, 432, 1])
  65.  
  66. # Tensorboard visualization
  67. tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
  68. tf.summary.scalar(name='Discriminator Loss', tensor=dc_loss)
  69. tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
  70. tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
  71. tf.summary.histogram(name='Real Distribution', values=real_distribution)
  72. tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
  73. tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
  74. summary_op = tf.summary.merge_all()
  75.  
  76. # Saving the model
  77. saver = tf.train.Saver()
  78. step = 0
  79. with tf.Session() as sess:
  80. if train_model:
  81. tensorboard_path, saved_model_path, log_path = form_results()
  82. sess.run(init)
  83. writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
  84. for i in range(n_epochs):
  85. # print(n_epochs)
  86. n_batches = int(10000 / batch_size)
  87. print("------------------Epoch {}/{}------------------".format(i, n_epochs))
  88. for b in range(1, n_batches+1):
  89. # print("In the loop")
  90. z_real_dist = np.random.randn(batch_size, z_dim) * 5.
  91. batch_x, batch_y = next_batch(batch_size)
  92. # print("Created the batches")
  93. sess.run(autoencoder_optimizer, feed_dict={x_input: batch_x, x_target: batch_x, y_input: batch_y})
  94. print("batch_x", batch_x)
  95. print("x_input:", x_input)
  96. print("x_target:", x_target)
  97. print("y_input:", y_input)
  98. sess.run(discriminator_optimizer,
  99. feed_dict={x_input: batch_x, x_target: batch_x, real_distribution: z_real_dist})
  100. sess.run(generator_optimizer, feed_dict={x_input: batch_x, x_target: batch_x})
  101. # print("setup the session")
  102. if b % 50 == 0:
  103. a_loss, d_loss, g_loss, summary = sess.run(
  104. [autoencoder_loss, dc_loss, generator_loss, summary_op],
  105. feed_dict={x_input: batch_x, x_target: batch_x,
  106. real_distribution: z_real_dist, y_input: batch_y})
  107. writer.add_summary(summary, global_step=step)
  108. print("Epoch: {}, iteration: {}".format(i, b))
  109. print("Autoencoder Loss: {}".format(a_loss))
  110. print("Discriminator Loss: {}".format(d_loss))
  111. print("Generator Loss: {}".format(g_loss))
  112. with open(log_path + '/log.txt', 'a') as log:
  113. log.write("Epoch: {}, iteration: {}n".format(i, b))
  114. log.write("Autoencoder Loss: {}n".format(a_loss))
  115. log.write("Discriminator Loss: {}n".format(d_loss))
  116. log.write("Generator Loss: {}n".format(g_loss))
  117. step += 1
  118.  
  119. saver.save(sess, save_path=saved_model_path, global_step=step)
  120. else:
  121. # Get the latest results folder
  122. all_results = os.listdir(results_path)
  123. all_results.sort()
  124. saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' +
  125. all_results[-1] + '/Saved_models/'))
  126. generate_image_grid(sess, op=decoder_image)
  127.  
  128.  
  129. if __name__ == '__main__':
  130. parser = argparse.ArgumentParser(description="Autoencoder Train Parameter")
  131. parser.add_argument('--train', '-t', type=bool, default=True,
  132. help='Set to True to train a new model, False to load weights and display image grid')
  133. args = parser.parse_args()
  134. train(train_model=args.train)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement