# Setup
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Load Dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

# Reshape Data: each MNIST image arrives as a flat 784-vector,
# so reshape it back to 28x28 (and keep the result).
batch_image = mnist.train.next_batch(1)[0]
batch_image = batch_image.reshape([28, 28])

# Implement LeakyReLU
def leaky_relu(x, alpha=0.01):
    # Where x is below 0 this returns alpha*x; otherwise it returns x.
    activation = tf.maximum(x, alpha * x)
    return activation

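# Optional sanity check (my addition, not in the original paste): LeakyReLU
# should pass positives through unchanged and scale negatives by alpha.
with tf.Session() as _check_sess:
    print(_check_sess.run(leaky_relu(tf.constant([-1.0, 0.0, 2.0]))))
    # Expected: [-0.01  0.    2.  ]
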
# Random Noise
def sample_noise(batch_size, dim):
    # Uniform noise in [-1, 1) with shape [batch_size, dim].
    random_noise = tf.random_uniform(shape=[batch_size, dim], minval=-1, maxval=1)
    return random_noise

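# Optional shape check (my addition): the noise tensor should have one row
# per batch element and dim columns.
print(sample_noise(3, 5).shape)  # (3, 5)
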
# Discriminator
def discriminator(x):
    with tf.variable_scope("discriminator"):
        # Two LeakyReLU hidden layers, then one raw logit per image
        # (no sigmoid here; the loss applies it).
        fc1 = tf.layers.dense(inputs=x, units=256, activation=leaky_relu)
        fc2 = tf.layers.dense(inputs=fc1, units=256, activation=leaky_relu)
        logits = tf.layers.dense(inputs=fc2, units=1)
        return logits

# Generator
def generator(z):
    with tf.variable_scope("generator"):
        # Two ReLU hidden layers; tanh keeps outputs in [-1, 1], matching
        # the preprocessed real images.
        fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
        fc2 = tf.layers.dense(inputs=fc1, units=1024, activation=tf.nn.relu)
        img = tf.layers.dense(inputs=fc2, units=784, activation=tf.nn.tanh)
        return img

# Compute GAN Loss
def gan_loss(logits_real, logits_fake):
    # Target label vector of ones, used for the generator loss and for the
    # real-image half of the discriminator loss.
    true_labels = tf.ones_like(logits_fake)

    # DISCRIMINATOR loss has 2 parts: how well it classifies real images
    # (label 1) and how well it classifies fake images (label 0).
    real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
    fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=tf.zeros_like(logits_fake))

    # Combine and average the losses over the batch.
    D_loss = real_image_loss + fake_image_loss
    D_loss = tf.reduce_mean(D_loss)

    # GENERATOR is trying to make the discriminator output 1 for all its
    # images, so its loss uses the target label vector of ones.
    G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)

    # Average the generator loss over the batch.
    G_loss = tf.reduce_mean(G_loss)

    return D_loss, G_loss

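# Optional sanity check (my addition): with uninformative logits of zero the
# sigmoid outputs 0.5 everywhere, so each cross-entropy term equals ln(2),
# giving D_loss = 2*ln(2) ~ 1.386 and G_loss = ln(2) ~ 0.693.
with tf.Session() as _check_sess:
    _d, _g = _check_sess.run(gan_loss(tf.zeros([4, 1]), tf.zeros([4, 1])))
    print(_d, _g)  # ~1.3863 ~0.6931
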
# Optimizing GAN Loss
def get_solvers(learning_rate=1e-3, beta1=0.5):
    # Create Adam solvers for GAN training
    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    return D_solver, G_solver

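##############################################################################
# Helper stand-ins (my additions): the code below calls preprocess_img,
# show_images, and get_session, but the paste never defines them. These are
# minimal sketches of what such helpers typically look like, not the
# original implementations.

def preprocess_img(x):
    # Rescale images from [0, 1] to [-1, 1] to match the generator's tanh range.
    return 2.0 * x - 1.0

def show_images(images):
    # Plot a square grid of flattened 28x28 grayscale images.
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    fig = plt.figure(figsize=(sqrtn, sqrtn))
    for i, img in enumerate(images):
        ax = fig.add_subplot(sqrtn, sqrtn, i + 1)
        ax.axis('off')
        ax.imshow(img.reshape([28, 28]), cmap='gray')
    return fig

def get_session():
    # Session that grabs GPU memory only as needed.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)
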
##############################################################################
# Final Model

# Number of Images for Each Batch
batch_size = 128
# Noise Dimension
noise_dim = 96

# Placeholder for Images from the Training Dataset
x = tf.placeholder(tf.float32, [None, 784])
# Random Noise for the Generator
z = sample_noise(batch_size, noise_dim)
# Generated Images
G_sample = generator(z)

with tf.variable_scope("") as scope:
    # Scale Images to be -1 to 1
    logits_real = discriminator(preprocess_img(x))
    # Re-use the discriminator weights on new inputs, so real and fake
    # images are scored by the same network.
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)

# Get the List of Variables for the Discriminator and Generator
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

# Get the Solvers
D_solver, G_solver = get_solvers()

# Get the Losses
D_loss, G_loss = gan_loss(logits_real, logits_fake)

# Setup Training Steps
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
##############################################################################

# Training a GAN
def training_gan(sess, G_train_step, G_loss, D_train_step, D_loss,
                 G_extra_step, D_extra_step,
                 show_every=250, print_every=50, batch_size=128, num_epoch=10):

    # Compute the Number of Iterations Needed
    max_iter = int(mnist.train.num_examples * num_epoch / batch_size)
    for it in range(max_iter):

        # Every show_every Iterations, Show a Sample Result
        if it % show_every == 0:
            samples = sess.run(G_sample)
            fig = show_images(samples[:16])
            plt.show()
            print()

        # Run a Batch of Data (the labels are unused for GAN training)
        minibatch, minibatch_y = mnist.train.next_batch(batch_size)
        _, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
        _, G_loss_curr = sess.run([G_train_step, G_loss])

        # Every print_every Iterations, Print the Losses
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G: {:.4}'.format(it, D_loss_curr, G_loss_curr))

    print('Final images')
    samples = sess.run(G_sample)

    fig = show_images(samples[:16])
    plt.show()

# Run the helper function
with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    training_gan(sess, G_train_step, G_loss, D_train_step, D_loss,
                 G_extra_step, D_extra_step)