Not a member of Pastebin yet? Sign up — it unlocks many cool features!
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Seed the TF graph-level RNG (weight initialisation) and NumPy as well.
# NOTE(review): the MNIST helper's next_batch() presumably shuffles with
# numpy.random, so without np.random.seed the batch order differs run to run.
tf.set_random_seed(1)
np.random.seed(1)

# Hyper Parameters
BATCH_SIZE = 64
LR = 0.002  # learning rate
N_TEST_IMG = 5  # number of test digits shown in the live figure

# one_hot=False -> labels are integer class ids, not one-hot vectors.
mnist = input_data.read_data_sets('./mnist', one_hot=False)
test_x = mnist.test.images[:200]
test_y = mnist.test.labels[:200]
print(mnist.train.images.shape)  # (55000, 28 * 28) == (55000, 784)
print(mnist.test.labels.shape)  # (10000,) -- 1-D integer labels because one_hot=False
# Flattened 28x28 input images, any batch size.
tf_x = tf.placeholder(tf.float32, [None, 28 * 28])


def _dense_chain(inputs, widths, activation):
    """Stack one dense layer per entry of *widths* (in order) and return
    the output tensor of every layer, first to last."""
    outputs = []
    net = inputs
    for width in widths:
        net = tf.layers.dense(net, width, activation)
        outputs.append(net)
    return outputs


# Encoder: 256 -> 128 -> 64 tanh layers, then a linear 32-unit bottleneck.
en0, en1, en2 = _dense_chain(tf_x, (256, 128, 64), tf.nn.tanh)
encoded = tf.layers.dense(en2, 32)

# Decoder: mirror of the encoder; the sigmoid bounds every output in (0, 1).
de0, de1, de2 = _dense_chain(encoded, (64, 128, 256), tf.nn.tanh)
decoded = tf.layers.dense(de2, 28 * 28, tf.nn.sigmoid)

# Pixel-wise mean squared reconstruction error, minimised with Adam.
loss = tf.losses.mean_squared_error(labels=tf_x, predictions=decoded)
train = tf.train.AdamOptimizer(LR).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Live figure: the top row holds the original test digits, the bottom row is
# redrawn with their reconstructions during training.
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()  # interactive mode so the figure refreshes without blocking

# Fixed set of test images whose reconstructions are tracked over time.
view_data = mnist.test.images[:N_TEST_IMG]
for ax, img in zip(a[0], view_data):
    ax.imshow(np.reshape(img, (28, 28)), cmap='gray')
    ax.set_xticks(())
    ax.set_yticks(())
N_STEPS = 5000  # total number of optimisation steps
PLOT_EVERY = 100  # report loss and refresh the figure every N steps

for step in range(N_STEPS):
    # Labels are irrelevant for an autoencoder; only the images are fed.
    b_x, _ = mnist.train.next_batch(BATCH_SIZE)
    # Fetch only what is used. Previously `encoded`/`decoded` were also fetched
    # every step and immediately discarded (wasted device->host copies).
    _, loss_ = sess.run([train, loss], {tf_x: b_x})
    if step % PLOT_EVERY == 0:
        print('train loss: %.4f' % loss_)
        # Bottom row: current reconstructions of the fixed view_data digits.
        decoded_data = sess.run(decoded, {tf_x: view_data})
        for ax, img in zip(a[1], decoded_data):
            ax.clear()
            ax.imshow(np.reshape(img, (28, 28)), cmap='gray')
            ax.set_xticks(())
            ax.set_yticks(())
        plt.draw()
        plt.pause(0.01)

plt.ioff()
import os

# Number of reconstructions to dump per split. Training-set export stays
# disabled (0), matching the original `range(0, 0)` loop; raise it to enable.
N_TRAIN_EXPORT = 0
N_TEST_EXPORT = 500

# Build the JPEG-encoding/writing ops ONCE. The previous code created fresh
# round/cast/reshape/encode_jpeg/write_file ops inside the loop, so the default
# graph grew on every iteration and each sess.run became slower.
tf_filename = tf.placeholder(tf.string, [])
decoded_uint8 = tf.cast(tf.round(decoded * 255), tf.uint8)
# Valid only when a single image is fed: 784 values -> one 28x28 grayscale image.
jpeg_bytes = tf.image.encode_jpeg(tf.reshape(decoded_uint8, [28, 28, 1]))
write_jpeg = tf.write_file(tf_filename, jpeg_bytes)


def _export_reconstructions(images, labels, split_dir, count):
    """Write the autoencoder reconstruction of the first *count* images.

    Each image is fed individually and saved as
    ``<split_dir>/<label>/<index>.jpg``. *labels* must come from the SAME
    split as *images* (the old train loop mixed train images with test
    labels, filing images under the wrong class).
    """
    for index in range(count):
        label_dir = os.path.join(split_dir, str(labels[index]))
        if not os.path.isdir(label_dir):
            os.makedirs(label_dir)
        if index % 10 == 0:
            print("Image number: {}".format(index))
        out_path = os.path.join(label_dir, str(index) + '.jpg')
        sess.run(write_jpeg, {tf_x: [images[index]], tf_filename: out_path})


_export_reconstructions(mnist.train.images, mnist.train.labels, 'train', N_TRAIN_EXPORT)
_export_reconstructions(mnist.test.images, mnist.test.labels, 'test', N_TEST_EXPORT)
Add comment
Please sign in to add a comment.