Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def next_batch(batch_size, label=(0, 1, 0, 0, 0), data_dir=None):
    """
    Build a training batch by sampling image files, with replacement, from a directory.

    Each sampled file is read with ``cv2.imread(path, 0)`` (grayscale), flattened
    to 1-D, and paired with the same class-label vector ``label``.

    :param batch_size: number of images to sample.
    :param label: class-label vector attached to every image in the batch.
        Defaults to the previously hard-coded ``[0, 1, 0, 0, 0]``.
    :param data_dir: directory to sample from; defaults to the module-level ``mnist``.
    :return: tuple ``(X, Y)`` of numpy arrays. ``X`` holds one flattened image per
        row; ``Y`` holds ``label`` repeated once per row.
    :raises IOError: if there is nothing to sample from, or a sampled file cannot
        be decoded as an image.
    """
    if data_dir is None:
        data_dir = mnist
    # List the directory once rather than once per sample.
    file_names = os.listdir(data_dir)
    if batch_size > 0 and not file_names:
        raise IOError("No files to sample in {!r}".format(data_dir))

    images = []
    for _ in range(batch_size):
        # os.path.join works whether or not data_dir ends with a separator.
        path = os.path.join(data_dir, random.choice(file_names))
        img = cv2.imread(path, 0)
        if img is None:
            # cv2.imread reports unreadable / non-image files by returning None;
            # letting that through would yield a silently broken object-dtype batch.
            raise IOError("Could not read image file {!r}".format(path))
        images.append(np.asarray(img).ravel())

    X = np.array(images)
    Y = np.array([list(label)] * batch_size)
    return X, Y
def _report_losses(log_path, epoch, iteration, a_loss, d_loss, g_loss):
    """Print the current losses and append the same lines, newline-terminated, to ``log_path + '/log.txt'``."""
    lines = [
        "Epoch: {}, iteration: {}".format(epoch, iteration),
        "Autoencoder Loss: {}".format(a_loss),
        "Discriminator Loss: {}".format(d_loss),
        "Generator Loss: {}".format(g_loss),
    ]
    for line in lines:
        print(line)
    with open(log_path + '/log.txt', 'a') as log:
        log.write("".join(line + "\n" for line in lines))


def train(train_model=True):
    """
    Used to train the autoencoder by passing in the necessary inputs.

    Builds the encoder/decoder/discriminator ops, the autoencoder, discriminator
    and generator losses, their Adam optimizers and TensorBoard summaries.

    :param train_model: True -> train the model: losses and summaries are recorded
        every 50 batches and a checkpoint is saved after every epoch.
        False -> restore the latest checkpoint found under ``results_path`` and
        show the image grid.
    :return: does not return anything
    """
    with tf.variable_scope(tf.get_variable_scope()):
        encoder_output = encoder(x_input)
        # Concat class label and the encoder output
        decoder_input = tf.concat([y_input, encoder_output], 1)
        decoder_output = decoder(decoder_input)

    with tf.variable_scope(tf.get_variable_scope()):
        d_real = discriminator(real_distribution)
        d_fake = discriminator(encoder_output, reuse=True)

    with tf.variable_scope(tf.get_variable_scope()):
        decoder_image = decoder(manual_decoder_input, reuse=True)

    # Autoencoder loss: mean squared reconstruction error
    autoencoder_loss = tf.reduce_mean(tf.square(x_target - decoder_output))

    # Discriminator loss: samples from real_distribution are labelled 1, encoder outputs 0
    dc_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real))
    dc_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake))
    dc_loss = dc_loss_fake + dc_loss_real

    # Generator loss: the encoder is pushed to make the discriminator label its output as 1
    generator_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake))

    all_variables = tf.trainable_variables()
    dc_var = [var for var in all_variables if 'dc_' in var.name]
    # NOTE(review): this is a substring match, so any variable whose name merely
    # contains 'e_' lands here, not only encoder variables -- confirm against the
    # variable names created by encoder()/decoder().
    en_var = [var for var in all_variables if 'e_' in var.name]

    # Optimizers
    autoencoder_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                   beta1=beta1).minimize(autoencoder_loss)
    discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                     beta1=beta1).minimize(dc_loss, var_list=dc_var)
    generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                 beta1=beta1).minimize(generator_loss, var_list=en_var)

    init = tf.global_variables_initializer()

    # Reshape images to display them
    input_images = tf.reshape(x_input, [-1, 368, 432, 1])
    generated_images = tf.reshape(decoder_output, [-1, 368, 432, 1])

    # Tensorboard visualization
    tf.summary.scalar(name='Autoencoder Loss', tensor=autoencoder_loss)
    tf.summary.scalar(name='Discriminator Loss', tensor=dc_loss)
    tf.summary.scalar(name='Generator Loss', tensor=generator_loss)
    tf.summary.histogram(name='Encoder Distribution', values=encoder_output)
    tf.summary.histogram(name='Real Distribution', values=real_distribution)
    tf.summary.image(name='Input Images', tensor=input_images, max_outputs=10)
    tf.summary.image(name='Generated Images', tensor=generated_images, max_outputs=10)
    summary_op = tf.summary.merge_all()

    # Saving the model
    saver = tf.train.Saver()
    step = 0
    with tf.Session() as sess:
        if train_model:
            tensorboard_path, saved_model_path, log_path = form_results()
            sess.run(init)
            writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=sess.graph)
            # Loop-invariant: next_batch samples files at random, so an "epoch"
            # is simply a fixed number of batches.
            n_batches = int(10000 / batch_size)
            for epoch in range(n_epochs):
                print("------------------Epoch {}/{}------------------".format(epoch, n_epochs))
                for b in range(1, n_batches + 1):
                    z_real_dist = np.random.randn(batch_size, z_dim) * 5.
                    batch_x, batch_y = next_batch(batch_size)
                    sess.run(autoencoder_optimizer,
                             feed_dict={x_input: batch_x, x_target: batch_x, y_input: batch_y})
                    sess.run(discriminator_optimizer,
                             feed_dict={x_input: batch_x, x_target: batch_x,
                                        real_distribution: z_real_dist})
                    sess.run(generator_optimizer,
                             feed_dict={x_input: batch_x, x_target: batch_x})
                    if b % 50 == 0:
                        a_loss, d_loss, g_loss, summary = sess.run(
                            [autoencoder_loss, dc_loss, generator_loss, summary_op],
                            feed_dict={x_input: batch_x, x_target: batch_x,
                                       real_distribution: z_real_dist, y_input: batch_y})
                        writer.add_summary(summary, global_step=step)
                        _report_losses(log_path, epoch, b, a_loss, d_loss, g_loss)
                    step += 1
                saver.save(sess, save_path=saved_model_path, global_step=step)
        else:
            # Get the latest results folder (by lexicographic order of folder names;
            # presumably they are timestamped -- confirm against form_results()).
            all_results = os.listdir(results_path)
            all_results.sort()
            saver.restore(sess, save_path=tf.train.latest_checkpoint(results_path + '/' +
                                                                     all_results[-1] + '/Saved_models/'))
            generate_image_grid(sess, op=decoder_image)
def str_to_bool(value):
    """
    Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` cannot be used for this: ``bool('False')`` is True
    because every non-empty string is truthy.

    :param value: the raw argument string (a bool is passed through unchanged).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Autoencoder Train Parameter")
    parser.add_argument('--train', '-t', type=str_to_bool, default=True,
                        help='Set to True to train a new model, False to load weights and display image grid')
    args = parser.parse_args()
    train(train_model=args.train)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement