warrior98

autoencoder

Aug 15th, 2018
104
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python | 3.19 KB | None | 0 0
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5.  
  6. tf.set_random_seed(1)
  7.  
  8. # Hyper Parameters
  9. BATCH_SIZE = 64
  10. LR = 0.002         # learning rate
  11. N_TEST_IMG = 5
  12.  
  13. mnist = input_data.read_data_sets('./mnist', one_hot=False)
  14. test_x = mnist.test.images[:200]
  15. test_y = mnist.test.labels[:200]
  16.  
  17. print(mnist.train.images.shape)     # (55000, 28 * 28)
  18. print(mnist.test.labels.shape)     # (55000, 10)
  19.  
  20. tf_x = tf.placeholder(tf.float32, [None, 28*28])
  21.  
  22. # encoder
  23. en0 = tf.layers.dense(tf_x, 256, tf.nn.tanh)
  24. en1 = tf.layers.dense(en0, 128, tf.nn.tanh)
  25. en2 = tf.layers.dense(en1, 64, tf.nn.tanh)
  26. encoded = tf.layers.dense(en2, 32)
  27.  
  28. # decoder
  29. de0 = tf.layers.dense(encoded, 64, tf.nn.tanh)
  30. de1 = tf.layers.dense(de0, 128, tf.nn.tanh)
  31. de2 = tf.layers.dense(de1, 256, tf.nn.tanh)
  32. decoded = tf.layers.dense(de2, 28*28, tf.nn.sigmoid)
  33.  
  34. loss = tf.losses.mean_squared_error(labels=tf_x, predictions=decoded)
  35. train = tf.train.AdamOptimizer(LR).minimize(loss)
  36.  
  37. sess = tf.Session()
  38. sess.run(tf.global_variables_initializer())
  39.  
  40. # initialize figure
  41. f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
  42. plt.ion()   # continuously plot
  43.  
  44. # original data (first row) for viewing
  45. view_data = mnist.test.images[:N_TEST_IMG]
  46. for i in range(N_TEST_IMG):
  47.     a[0][i].imshow(np.reshape(view_data[i], (28, 28)), cmap='gray')
  48.     a[0][i].set_xticks(())
  49.     a[0][i].set_yticks(())
  50.  
  51. for step in range(5000):
  52.     b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
  53.     _, encoded_, decoded_, loss_ = sess.run(
  54.         [train, encoded, decoded, loss], {tf_x: b_x})
  55.  
  56.     if step % 100 == 0:     # plotting
  57.         print('train loss: %.4f' % loss_)
  58.         # plotting decoded image (second row)
  59.         decoded_data = sess.run(decoded, {tf_x: view_data})
  60.         for i in range(N_TEST_IMG):
  61.             a[1][i].clear()
  62.             a[1][i].imshow(np.reshape(decoded_data[i], (28, 28)), cmap='gray')
  63.             a[1][i].set_xticks(())
  64.             a[1][i].set_yticks(())
  65.         plt.draw()
  66.         plt.pause(0.01)
  67. plt.ioff()
  68.  
  69. for index in range(0, 0):
  70.     decoded_data = decoded
  71.     decoded_data_unnormalized = tf.round(decoded_data*255)
  72.     decoded_data_uint8 = tf.cast(decoded_data_unnormalized, tf.uint8)
  73.     decoded_data_uint8_reshaped = tf.reshape(decoded_data_uint8, [28, 28, 1])
  74.     aux = tf.image.encode_jpeg(decoded_data_uint8_reshaped)
  75.  
  76.     writer = tf.write_file(
  77.         'train\\'+str(mnist.test.labels[index])+'\\'+str(index)+'.jpg', aux)
  78.  
  79.     if index % 10 == 0:
  80.         print("Image number: {}".format(index))
  81.  
  82.     sess.run(writer, {tf_x: [mnist.train.images[index]]})
  83.  
  84. for index in range(0, 500):
  85.     decoded_data = decoded
  86.     decoded_data_unnormalized = tf.round(decoded_data*255)
  87.     decoded_data_uint8 = tf.cast(decoded_data_unnormalized, tf.uint8)
  88.     decoded_data_uint8_reshaped = tf.reshape(decoded_data_uint8, [28, 28, 1])
  89.     aux = tf.image.encode_jpeg(decoded_data_uint8_reshaped)
  90.  
  91.     writer = tf.write_file(
  92.         'test\\'+str(mnist.test.labels[index])+'\\'+str(index)+'.jpg', aux)
  93.     if index % 10 == 0:
  94.         print("Image number: {}".format(index))
  95.  
  96.     sess.run(writer, {tf_x: [mnist.test.images[index]]})
Add Comment
Please, Sign In to add comment