  1. """ Deep Convolutional Generative Adversarial Network (DCGAN).
  2. Using deep convolutional generative adversarial networks (DCGAN) to generate
  3. digit images from a noise distribution.
  4. References:
  5.    - Unsupervised representation learning with deep convolutional generative
  6.    adversarial networks. A Radford, L Metz, S Chintala. arXiv:1511.06434.
  7. Links:
  8.    - [DCGAN Paper](https://arxiv.org/abs/1511.06434).
  9.    - [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
  10. Author: Aymeric Damien
  11. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  12. """
  13.  
  14. from __future__ import division, print_function, absolute_import
  15.  
  16. import matplotlib.pyplot as plt
  17. import numpy as np
  18. import tensorflow as tf
  19.  
  20. # Import MNIST data
  21. from tensorflow.examples.tutorials.mnist import input_data
  22. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
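# Note: this script targets TensorFlow 1.x. tf.placeholder, tf.layers, tf.contrib
# and the tensorflow.examples.tutorials.mnist reader were all removed in
# TensorFlow 2.x, so it will not run unmodified on TF 2.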

# Training Params
num_steps = 20000
batch_size = 32

# Network Params
image_dim = 784        # 28*28 pixels * 1 channel (the placeholders below use the explicit 28x28x1 shape)
gen_hidden_dim = 256   # not used by the convolutional generator below
disc_hidden_dim = 256  # not used by the convolutional discriminator below
noise_dim = 784  # Size of the latent noise vector fed to the generator. It does
                 # not need to match the image dimensions (the upstream example
                 # uses 200): the generator's first dense layer maps it to
                 # 6*6*128 features regardless of its size.


# Generator Network
# Input: Noise, Output: Image
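# Note: the DCGAN paper recommends ReLU activations with batch normalization in
# the generator; this example uses plain tanh activations without batch norm.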
def generator(x, reuse=False):
    with tf.variable_scope('Generator', reuse=reuse):
        # TensorFlow Layers automatically create variables and calculate their
        # shape, based on the input.
        x = tf.layers.dense(x, units=6 * 6 * 128)
        x = tf.nn.tanh(x)
        # Reshape to a 4-D array of images: (batch, height, width, channels)
        # New shape: (batch, 6, 6, 128)
        x = tf.reshape(x, shape=[-1, 6, 6, 128])
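        # With the default padding='valid', a transposed convolution produces
        # output_size = (input_size - 1) * stride + kernel_size, so the two
        # layers below upsample 6x6 -> (6-1)*2+4 = 14x14 -> (14-1)*2+2 = 28x28.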
        # Deconvolution, image shape: (batch, 14, 14, 64)
        x = tf.layers.conv2d_transpose(x, 64, 4, strides=2)
        # Deconvolution, image shape: (batch, 28, 28, 1)
        x = tf.layers.conv2d_transpose(x, 1, 2, strides=2)
        # Apply sigmoid to squash pixel values into (0, 1)
        x = tf.nn.sigmoid(x)
        return x


# Discriminator Network
# Input: Image, Output: Prediction Real/Fake Image
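# Note: the DCGAN paper uses strided convolutions with LeakyReLU activations in
# the discriminator; this example uses a LeNet-style stack of convolutions,
# tanh activations, and average pooling instead.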
def discriminator(x, reuse=False):
    with tf.variable_scope('Discriminator', reuse=reuse):
        # Typical convolutional neural network to classify images.
        x = tf.layers.conv2d(x, 64, 5)
        x = tf.nn.tanh(x)
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = tf.layers.conv2d(x, 128, 5)
        x = tf.nn.tanh(x)
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = tf.contrib.layers.flatten(x)
        x = tf.layers.dense(x, 1024)
        x = tf.nn.tanh(x)
        # Output 2 classes: Real and Fake images
        x = tf.layers.dense(x, 2)
    return x

# Build Networks
# Network Inputs
noise_input = tf.placeholder(tf.float32, shape=[None, noise_dim])
real_image_input = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
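# Note: MNIST batches come out of the reader as flat 784-dimensional vectors;
# they are reshaped to (batch, 28, 28, 1) in the training loop before being
# fed into real_image_input.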

# Build Generator Network
gen_sample = generator(noise_input)

# Build 2 Discriminator Networks (one for real images, one for generated samples)
disc_real = discriminator(real_image_input)
disc_fake = discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)
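# Concatenating the real and fake logits along the batch axis lets a single
# cross-entropy op score both halves at once: the first batch_size rows are
# labeled 1 (real) and the remaining batch_size rows are labeled 0 (fake),
# matching batch_disc_y in the training loop below.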

# Build the stacked generator/discriminator
stacked_gan = discriminator(gen_sample, reuse=True)
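# This third discriminator "copy" (again with reuse=True, so it shares weights)
# gives the generator loss a path through the discriminator; only the
# generator's variables are updated through it (see var_list below).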

# Build Targets (real or fake images)
# The leading None lets the batch dimension take any size, so the same
# placeholders can be fed with batches of different sizes.
disc_target = tf.placeholder(tf.int32, shape=[None])
gen_target = tf.placeholder(tf.int32, shape=[None])

# Build Loss
disc_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=disc_concat, labels=disc_target))
gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=stacked_gan, labels=gen_target))
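# gen_loss is the standard non-saturating generator objective: cross-entropy of
# the discriminator's output on generated samples against the "real" label (1),
# so the generator is rewarded whenever the discriminator is fooled.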

# Build Optimizers
optimizer_gen = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=0.001)

# Training Variables for each optimizer
# By default in TensorFlow, each optimizer updates all trainable variables, so
# we need to tell each one exactly which variables it should update.
# Generator Network Variables
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')
# Discriminator Network Variables
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')

# Create training operations
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)
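# Restricting each minimize() call to its own var_list means that running both
# train ops in a single sess.run (as done in the loop below) performs exactly
# one generator update and one discriminator update per step.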

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
configNum = 4  # suffix for the checkpoint file name
imgNum = 4     # prefix for the saved image file names
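
# The training loop below saves figures into ./results/, so make sure it exists.
import os
if not os.path.isdir('./results'):
    os.makedirs('./results')
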
# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)
    # Restore a previously saved checkpoint, if one exists, so training can resume.
    try:
        saver.restore(sess, "./model" + str(configNum) + ".ckpt")
        print("Model restored")
    except Exception:
        print("Model failed to be restored")

    # Fixed noise vectors (10 sets of 4) so the same latent points are
    # visualized at every checkpoint, making generator progress easy to compare.
    tz = [np.random.uniform(-1., 1., size=[4, noise_dim]) for _ in range(10)]
    for i in range(1, num_steps+1):

        # Prepare Input Data
        # Get the next batch of MNIST data (only images are needed, not labels)
        batch_x, _ = mnist.train.next_batch(batch_size)
        batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])
        # Generate noise to feed to the generator
        z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])

        # Prepare Targets (Real image: 1, Fake image: 0)
        # The first half of the data fed to the discriminator are real images,
        # the other half are fake images (coming from the generator).
        batch_disc_y = np.concatenate(
            [np.ones([batch_size]), np.zeros([batch_size])], axis=0)
        # Generator tries to fool the discriminator, thus its targets are 1.
        batch_gen_y = np.ones([batch_size])

        # Training
        feed_dict = {real_image_input: batch_x, noise_input: z,
                     disc_target: batch_disc_y, gen_target: batch_gen_y}
        _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                                feed_dict=feed_dict)

        # (A checkpoint could be saved here after every step; currently the
        # model is saved only once, after the training loop finishes.)

        if i % 100 == 0 or i == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))
            f, a = plt.subplots(4, 10, figsize=(10, 4))
            for k in range(10):  # 10 columns

                # Noise input.
                z = tz[k]

                g = sess.run(gen_sample, feed_dict={noise_input: z})
                for j in range(4):
                    # Generated image from noise. Repeat to 3 channels for the matplotlib figure.
                    img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                                     newshape=(28, 28, 3))
                    a[j][k].imshow(img)
            # Save the full 4x10 grid once per checkpoint step and close the
            # figure so figures do not pile up in memory during training.
            plt.savefig('./results/' + str(imgNum) + 'dcganimg' + str(i) + '.png')
            plt.close(f)

    save_path = saver.save(sess, "./model" + str(configNum) + ".ckpt")
    print("Model saved in path: %s" % save_path)
    # Generate a final grid of images from the fixed noise, using the generator network.
    f, a = plt.subplots(4, 10, figsize=(10, 4))
    for i in range(10):  # 10 columns
        # Noise input: one fixed batch of 4 noise vectors per column.
        z = tz[i]
        g = sess.run(gen_sample, feed_dict={noise_input: z})
        for j in range(4):
            # Generated image from noise. Repeat to 3 channels for the matplotlib figure.
            img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                             newshape=(28, 28, 3))
            a[j][i].imshow(img)

    f.show()
    plt.savefig('./results/dcganimgdone.png')
    plt.draw()
    plt.waitforbuttonpress()