Advertisement
Guest User

Untitled

a guest
Jan 17th, 2019
75
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 3.92 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4.  
  5. from tensorflow.examples.tutorials.mnist import input_data
  6. mnist = input_data.read_data_sets('MNIST_data')
  7.  
  8. print("---------------------------")
  9. print(mnist)
  10. print("---------------------------")
  11.  
  12. tf.reset_default_graph()
  13.  
  14. batch_size = 64
  15.  
  16. X_in = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='X')
  17. Y = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='Y')
  18. Y_flat = tf.reshape(Y, shape=[-1, 28 * 28])
  19. keep_prob = tf.placeholder(dtype=tf.float32, shape=(), name='keep_prob')
  20.  
  21. dec_in_channels = 1
  22. n_latent = 8
  23.  
  24. reshaped_dim = [-1, 7, 7, dec_in_channels]
  25. inputs_decoder = 49 * dec_in_channels / 2
  26.  
  27.  
  28. def lrelu(x, alpha=0.3):
  29. return tf.maximum(x, tf.multiply(x, alpha))
  30.  
  31.  
  32. def encoder(X_in, keep_prob):
  33. activation = lrelu
  34. with tf.variable_scope("encoder", reuse=None):
  35. X = tf.reshape(X_in, shape=[-1, 28, 28, 1])
  36. x = tf.layers.conv2d(X, filters=64, kernel_size=4, strides=2, padding='same', activation=activation)
  37. x = tf.nn.dropout(x, keep_prob)
  38. x = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=2, padding='same', activation=activation)
  39. x = tf.nn.dropout(x, keep_prob)
  40. x = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=1, padding='same', activation=activation)
  41. x = tf.nn.dropout(x, keep_prob)
  42. x = tf.contrib.layers.flatten(x)
  43. mn = tf.layers.dense(x, units=n_latent)
  44. sd = 0.5 * tf.layers.dense(x, units=n_latent)
  45. epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], n_latent]))
  46. z = mn + tf.multiply(epsilon, tf.exp(sd))
  47.  
  48. return z, mn, sd
  49.  
  50.  
  51. def decoder(sampled_z, keep_prob):
  52. with tf.variable_scope("decoder", reuse=None):
  53. x = tf.layers.dense(sampled_z, units=inputs_decoder, activation=lrelu)
  54. x = tf.layers.dense(x, units=inputs_decoder * 2 + 1, activation=lrelu)
  55. x = tf.reshape(x, reshaped_dim)
  56. x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=2, padding='same', activation=tf.nn.relu)
  57. x = tf.nn.dropout(x, keep_prob)
  58. x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, padding='same', activation=tf.nn.relu)
  59. x = tf.nn.dropout(x, keep_prob)
  60. x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, padding='same', activation=tf.nn.relu)
  61.  
  62. x = tf.contrib.layers.flatten(x)
  63. x = tf.layers.dense(x, units=28 * 28, activation=tf.nn.sigmoid)
  64. img = tf.reshape(x, shape=[-1, 28, 28])
  65. return img
  66.  
  67. sampled, mn, sd = encoder(X_in, keep_prob)
  68. dec = decoder(sampled, keep_prob)
  69.  
  70. unreshaped = tf.reshape(dec, [-1, 28*28])
  71. img_loss = tf.reduce_sum(tf.squared_difference(unreshaped, Y_flat), 1)
  72. latent_loss = -0.5 * tf.reduce_sum(1.0 + 2.0 * sd - tf.square(mn) - tf.exp(2.0 * sd), 1)
  73. loss = tf.reduce_mean(img_loss + latent_loss)
  74. optimizer = tf.train.AdamOptimizer(0.0005).minimize(loss)
  75. sess = tf.Session()
  76. sess.run(tf.global_variables_initializer())
  77.  
  78. for i in range(30000):
  79. batch = [np.reshape(b, [28, 28]) for b in mnist.train.next_batch(batch_size=batch_size)[0]]
  80. sess.run(optimizer, feed_dict={X_in: batch, Y: batch, keep_prob: 0.8})
  81.  
  82. if not i % 200:
  83. ls, d, i_ls, d_ls, mu, sigm = sess.run([loss, dec, img_loss, latent_loss, mn, sd],
  84. feed_dict={X_in: batch, Y: batch, keep_prob: 1.0})
  85. plt.imshow(np.reshape(batch[0], [28, 28]), cmap='gray')
  86. plt.show()
  87. plt.imshow(d[0], cmap='gray')
  88. plt.show()
  89. print(i, ls, np.mean(i_ls), np.mean(d_ls))
  90.  
  91.  
  92. randoms = [np.random.normal(0, 1, n_latent) for _ in range(10)]
  93. imgs = sess.run(dec, feed_dict = {sampled: randoms, keep_prob: 1.0})
  94. imgs = [np.reshape(imgs[i], [28, 28]) for i in range(len(imgs))]
  95.  
  96. for img in imgs:
  97. plt.figure(figsize=(1,1))
  98. plt.axis('off')
  99. plt.imshow(img, cmap='gray')
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement