Advertisement
Guest User

Untitled

a guest
Jun 20th, 2018
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.62 KB | None | 0 0
  1. def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape, data_image_mode):
  2. """
  3. Train the GAN
  4. :param epoch_count: Number of epochs
  5. :param batch_size: Batch Size
  6. :param z_dim: Z dimension
  7. :param learning_rate: Learning Rate
  8. :param beta1: The exponential decay rate for the 1st moment in the optimizer
  9. :param get_batches: Function to get batches
  10. :param data_shape: Shape of the data
  11. :param data_image_mode: The image mode to use for images ("RGB" or "L")
  12. """
  13. # TODO: Build Model
  14.  
  15. #for output
  16. print_every=50
  17. show_every=100
  18. steps=0
  19.  
  20. #get image width, height and channels .. data_shape[0] is the number of images
  21. image_width = data_shape[1]
  22. image_height=data_shape[2]
  23. image_channels=data_shape[3]
  24.  
  25.  
  26. graph_real_images, graph_z_data, graph_lr = model_inputs(image_width, image_height, image_channels, z_dim)
  27. graph_dis_loss, graph_gen_loss = model_loss(graph_real_images, graph_z_data, image_channels)
  28. graph_dis_opt, graph_gen_opt = model_opt(graph_dis_loss,graph_gen_loss,graph_lr,beta1)
  29.  
  30.  
  31.  
  32.  
  33.  
  34. with tf.Session() as sess:
  35. sess.run(tf.global_variables_initializer())
  36. for epoch_i in range(epoch_count):
  37. for batch_images in get_batches(batch_size):
  38.  
  39. steps+=1
  40.  
  41. z_data = np.random.uniform(-1, 1, size=(batch_size, z_dim))
  42.  
  43.  
  44. #run sessions ------------------------
  45.  
  46. _ = sess.run(graph_dis_opt, feed_dict={graph_real_images: batch_images,
  47. graph_z_data: z_data})
  48.  
  49. _ = sess.run(graph_gen_opt, feed_dict={graph_z_data: python_z_data})
  50.  
  51. #-------------------------------------
  52.  
  53. if steps % print_every == 0:
  54. train_loss_d = graph_dis_loss.eval({graph_z_data: python_z_data, graph_real_images: batch_images})
  55. train_loss_g = graph_gen_loss.eval({graph_z_data: python_z_data})
  56.  
  57. print("Epoch {}/{}...".format(epoch_i + 1, epochs),
  58. "Discriminator Loss: {:.4f}...".format(train_loss_d),
  59. "Generator Loss: {:.4f}".format(train_loss_g))
  60.  
  61. if steps % show_every == 0:
  62. show_generator_output(sess, 25, input_z, data_shape[3], data_image_mode)
  63. # TODO: Train Model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement