# Setup
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Load Dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
# Reshape Data
batch_image = mnist.train.next_batch(1)[0]
# reshape() returns a new array, so keep the result
batch_image = batch_image.reshape([28, 28])
# Implement LeakyReLU
def leaky_relu(x, alpha=0.01):
    # Returns x where x > 0, and alpha * x otherwise.
    activation = tf.maximum(x, alpha * x)
    return activation
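# A quick sanity check on leaky_relu (not in the original paste; it assumes the
# TF1 graph/session style used throughout):
with tf.Session() as check_sess:
    print(check_sess.run(leaky_relu(tf.constant([-1.0, 0.0, 2.0]))))
    # expected output: [-0.01  0.    2.  ]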
# Random Noise
def sample_noise(batch_size, dim):
    # Uniform noise in [-1, 1] with shape [batch_size, dim].
    random_noise = tf.random_uniform(shape=[batch_size, dim], minval=-1, maxval=1)
    return random_noise
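# A hypothetical check (not part of the original paste) that the noise has the
# expected shape and range:
with tf.Session() as check_sess:
    noise = check_sess.run(sample_noise(32, 96))
    assert noise.shape == (32, 96)
    assert noise.min() >= -1.0 and noise.max() <= 1.0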
# Discriminator
def discriminator(x):
    with tf.variable_scope("discriminator"):
        # Two fully connected LeakyReLU layers, then a single output logit.
        fc1 = tf.layers.dense(inputs=x, units=256, activation=leaky_relu)
        fc2 = tf.layers.dense(inputs=fc1, units=256, activation=leaky_relu)
        logits = tf.layers.dense(inputs=fc2, units=1)
        return logits
# Generator
def generator(z):
    with tf.variable_scope("generator"):
        # Two fully connected ReLU layers, then a tanh layer producing
        # 784 pixel values in [-1, 1].
        fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
        fc2 = tf.layers.dense(inputs=fc1, units=1024, activation=tf.nn.relu)
        img = tf.layers.dense(inputs=fc2, units=784, activation=tf.nn.tanh)
        return img
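# A hedged shape check for both networks (not in the original paste). It is
# built in a throwaway graph so its variables do not collide with the model
# constructed further below:
with tf.Graph().as_default():
    test_img = generator(sample_noise(4, 96))   # -> (4, 784)
    test_logits = discriminator(test_img)       # -> (4, 1)
    print(test_img.get_shape(), test_logits.get_shape())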
# Compute GAN Loss
def gan_loss(logits_real, logits_fake):
    # Target label vector for the generator loss, also used in the discriminator loss.
    true_labels = tf.ones_like(logits_fake)
    # The DISCRIMINATOR loss has two parts: how well it classifies real images
    # (label 1) and how well it classifies fake images (label 0).
    real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
    fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=tf.zeros_like(logits_fake))
    # Combine and average the losses over the batch.
    D_loss = tf.reduce_mean(real_image_loss + fake_image_loss)
    # The GENERATOR is trying to make the discriminator output 1 for all of its
    # images, so its loss uses the target label vector of ones.
    G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)
    # Average the generator loss over the batch.
    G_loss = tf.reduce_mean(G_loss)
    return D_loss, G_loss
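# A hypothetical numeric sanity check (not in the original paste). With all
# logits at 0, every sigmoid cross-entropy term equals -log(0.5) ~= 0.6931, so
# D_loss ~= 1.3863 and G_loss ~= 0.6931:
with tf.Graph().as_default(), tf.Session() as check_sess:
    zero_logits = tf.zeros([4, 1])
    d, g = check_sess.run(gan_loss(zero_logits, zero_logits))
    print(d, g)  # ~1.3863 ~0.6931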
# Optimizing GAN Loss
def get_solvers(learning_rate=1e-3, beta1=0.5):
    # Create Adam solvers for GAN training.
    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    return D_solver, G_solver
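# The model below calls preprocess_img(), which this paste never defines. A
# minimal sketch, assuming the usual rescaling of MNIST pixels from [0, 1] to
# [-1, 1] so real images match the generator's tanh output:
def preprocess_img(x):
    return 2 * x - 1.0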
##############################################################################
# Final Model
# Number of Images for Each Batch
batch_size = 128
# Noise Dimension
noise_dim = 96
# Placeholder for Images from the Training Dataset
x = tf.placeholder(tf.float32, [None, 784])
# Random Noise for the Generator
z = sample_noise(batch_size, noise_dim)
# Generated Images
G_sample = generator(z)
with tf.variable_scope("") as scope:
    # Scale Images to the Range [-1, 1]
    logits_real = discriminator(preprocess_img(x))
    # Re-use Discriminator Weights on New Inputs
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)
# Get the List of Variables for the Discriminator and Generator
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
# Get the Solvers
D_solver, G_solver = get_solvers()
# Get the Losses
D_loss, G_loss = gan_loss(logits_real, logits_fake)
# Setup Training Steps
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
##############################################################################
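# training_gan() below also relies on show_images() and get_session(), which
# the paste never defines. Minimal sketches, assuming a 4x4 grayscale grid and
# a session that allocates GPU memory on demand:
def show_images(images):
    fig = plt.figure(figsize=(4, 4))
    for i, img in enumerate(images):
        ax = fig.add_subplot(4, 4, i + 1)
        ax.axis('off')
        ax.imshow(img.reshape(28, 28), cmap='gray')
    return fig

def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)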
# Training a GAN
def training_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
                 show_every=250, print_every=50, batch_size=128, num_epoch=10):
    # Compute the Number of Iterations Needed
    max_iter = int(mnist.train.num_examples * num_epoch / batch_size)
    for it in range(max_iter):
        # Every show_every Iterations, Show a Sample of Generated Images
        if it % show_every == 0:
            samples = sess.run(G_sample)
            fig = show_images(samples[:16])
            plt.show()
            print()
        # Run a Batch of Data
        minibatch, minibatch_y = mnist.train.next_batch(batch_size)
        _, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
        _, G_loss_curr = sess.run([G_train_step, G_loss])
        # Every print_every Iterations, Print the Losses
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G: {:.4}'.format(it, D_loss_curr, G_loss_curr))
    print('Final images')
    samples = sess.run(G_sample)
    fig = show_images(samples[:16])
    plt.show()
# Run the helper function
with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    training_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step)