Advertisement
warrior98

cnn+autoencoder

Aug 21st, 2018
143
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.40 KB | None | 0 0
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5.  
  6. tf.set_random_seed(1)
  7.  
  8. np.random.seed(1)
  9.  
  10. BATCH_SIZE = 64
  11. LR = 0.001              # learning rate
  12.  
  13. tf_x = tf.placeholder(tf.float32, [None, 28*28]) / 255.
  14. # (batch, height, width, channel)
  15. image = tf.reshape(tf_x, [-1, 28, 28, 1])
  16. tf_y = tf.placeholder(tf.int32, [None, 10])            # input y
  17.  
  18. # CNN
  19. conv1 = tf.layers.conv2d(inputs=image, filters=16, kernel_size=5, strides=1,
  20.                          padding='same', activation=tf.nn.relu)
  21. pool1 = tf.layers.max_pooling2d(conv1, pool_size=2, strides=2,)
  22. conv2 = tf.layers.conv2d(pool1, 32, 5, 1, 'same',
  23.                          activation=tf.nn.relu)    # -> (14, 14, 32)
  24. pool2 = tf.layers.max_pooling2d(conv2, 2, 2)    # -> (7, 7, 32)
  25. flat = tf.reshape(pool2, [-1, 7*7*32])          # -> (7*7*32, )
  26. output = tf.layers.dense(flat, 10)              # output layer
  27.  
  28. loss = tf.losses.softmax_cross_entropy(
  29.     onehot_labels=tf_y, logits=output)           # compute cost
  30. train_op = tf.train.AdamOptimizer(LR).minimize(loss)
  31.  
  32. accuracy = tf.metrics.accuracy(
  33.     labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]
  34.  
  35.  
  36. # Hyper Parameters
  37. BATCH_SIZE2 = 64
  38. LR2 = 0.002         # learning rate
  39. N_TEST_IMG2 = 5
  40.  
  41. # Mnist digits
  42. # use not one-hotted target data
  43. mnist = input_data.read_data_sets('./mnist', one_hot=True)
  44. test_x = mnist.test.images[:5500]
  45. test_y = mnist.test.labels[:5500]
  46.  
  47. x = np.array([])
  48. y = np.array([])
  49. for dim in range(1, 33):
  50.     print('Bottleneck size is now {}'.format(dim))
  51.  
  52.     # tf placeholder
  53.     # value in the range of (0, 1)
  54.     tf_x2 = tf.placeholder(tf.float32, [None, 28*28])
  55.  
  56.     # encoder
  57.     en0 = tf.layers.dense(tf_x2, 256, tf.nn.tanh)
  58.     en1 = tf.layers.dense(en0, 128, tf.nn.tanh)
  59.     en2 = tf.layers.dense(en1, 64, tf.nn.tanh)
  60.  
  61.     encoded = tf.layers.dense(en2, dim)
  62.  
  63.     # decoder
  64.     de0 = tf.layers.dense(encoded, 64, tf.nn.tanh)
  65.     de1 = tf.layers.dense(de0, 128, tf.nn.tanh)
  66.     de2 = tf.layers.dense(de1, 256, tf.nn.tanh)
  67.     decoded = tf.layers.dense(de2, 28*28, tf.nn.sigmoid)
  68.  
  69.     loss2 = tf.losses.mean_squared_error(labels=tf_x2, predictions=decoded)
  70.     train = tf.train.AdamOptimizer(LR2).minimize(loss2)
  71.  
  72.     sess = tf.Session()
  73.     init_op = tf.group(tf.global_variables_initializer(),
  74.                        tf.local_variables_initializer())  # the local var is for accuracy_op
  75.     sess.run(init_op)     # initialize var in graph
  76.  
  77.     for step in range(2600):
  78.         b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
  79.         _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
  80.         if step % 100 == 0:
  81.             accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y})
  82.             print('Step:', step, '| train loss: %.4f' %
  83.                   loss_, '| test accuracy: %.2f' % accuracy_)
  84.  
  85.     # initialize figure
  86.     # f, a = plt.subplots(2, N_TEST_IMG2, figsize=(5, 2))
  87.     # plt.ion()   # continuously plot
  88.  
  89.     # original data (first row) for viewing
  90.     test_data = mnist.test.images
  91.     test_label = mnist.test.labels
  92.     view_data = mnist.test.images[:N_TEST_IMG2]
  93.  
  94.     # for i in range(N_TEST_IMG2):
  95.     #     a[0][i].imshow(np.reshape(view_data[i], (28, 28)), cmap='gray')
  96.     #     a[0][i].set_xticks(())
  97.     #     a[0][i].set_yticks(())
  98.  
  99.     sess.run(tf.local_variables_initializer())
  100.     for step in range(8000):
  101.         b_x, b_y = mnist.train.next_batch(BATCH_SIZE2)
  102.         _, encoded_, decoded_, loss2_ = sess.run(
  103.             [train, encoded, decoded, loss2], {tf_x2: b_x})
  104.  
  105.         if step == 0 or step % 200 == 199:     # plotting
  106.             # plotting decoded image (second row)
  107.             decoded_data = sess.run(decoded, {tf_x2: test_data})
  108.  
  109.             accuracy_ = sess.run(accuracy, feed_dict={
  110.                 tf_x: decoded_data, tf_y: test_label})
  111.  
  112.             print('train loss: %.4f' %
  113.                   loss2_, ' | test accuracy: %.4f' % accuracy_)
  114.  
  115.     #         for i in range(N_TEST_IMG2):
  116.     #             a[1][i].clear()
  117.     #             a[1][i].imshow(np.reshape(
  118.     #                 decoded_data[i], (28, 28)), cmap='gray')
  119.     #             a[1][i].set_xticks(())
  120.     #             a[1][i].set_yticks(())
  121.     #         plt.draw()
  122.     #         plt.pause(0.01)
  123.     # plt.ioff()
  124.  
  125.     x = np.append(x, dim)
  126.     y = np.append(y, accuracy_)
  127.  
  128. plt.plot(x, y)
  129. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement