Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches,
          data_shape, data_image_mode, print_every=50, show_every=100):
    """
    Train the GAN.

    Builds the graph via ``model_inputs``, ``model_loss`` and ``model_opt``,
    then, for every batch yielded by ``get_batches``, runs one discriminator
    update followed by one generator update using the same noise batch.

    :param epoch_count: Number of epochs
    :param batch_size: Batch Size (passed to ``get_batches``)
    :param z_dim: Z dimension (width of the uniform noise vectors)
    :param learning_rate: Learning Rate, fed to the learning-rate placeholder
        returned by ``model_inputs`` on every optimizer step
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches; called once per epoch with
        ``batch_size``
    :param data_shape: Shape of the data; index 0 is the number of images,
        indices 1, 2 and 3 are width, height and channels
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    :param print_every: Print both losses every this many training steps
    :param show_every: Show generator samples every this many training steps
    """
    # data_shape[0] is the number of images; the rest describe one image.
    image_width = data_shape[1]
    image_height = data_shape[2]
    image_channels = data_shape[3]

    # Build the model: placeholders, losses and optimizer ops.
    graph_real_images, graph_z_data, graph_lr = model_inputs(
        image_width, image_height, image_channels, z_dim)
    graph_dis_loss, graph_gen_loss = model_loss(
        graph_real_images, graph_z_data, image_channels)
    graph_dis_opt, graph_gen_opt = model_opt(
        graph_dis_loss, graph_gen_loss, graph_lr, beta1)

    # Train the model.
    steps = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epoch_count):
            for batch_images in get_batches(batch_size):
                steps += 1
                # NOTE(review): confirm get_batches yields pixels in the same
                # range the generator outputs; rescale here if it does not.
                # Size the noise from the actual batch so a short final batch
                # still lines up with the real images.
                z_data = np.random.uniform(
                    -1, 1, size=(len(batch_images), z_dim))

                # The optimizers were built from the learning-rate placeholder,
                # so it must be fed on every update. Real images are also fed
                # to the generator step in case the optimizer depends on
                # discriminator ops that read them (extra feeds are harmless).
                train_feed = {graph_real_images: batch_images,
                              graph_z_data: z_data,
                              graph_lr: learning_rate}
                sess.run(graph_dis_opt, feed_dict=train_feed)
                sess.run(graph_gen_opt, feed_dict=train_feed)

                if steps % print_every == 0:
                    train_loss_d, train_loss_g = sess.run(
                        [graph_dis_loss, graph_gen_loss],
                        feed_dict={graph_real_images: batch_images,
                                   graph_z_data: z_data})
                    print("Epoch {}/{}...".format(epoch_i + 1, epoch_count),
                          "Discriminator Loss: {:.4f}...".format(train_loss_d),
                          "Generator Loss: {:.4f}".format(train_loss_g))

                if steps % show_every == 0:
                    show_generator_output(sess, 25, graph_z_data,
                                          image_channels, data_image_mode)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement