#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sun May 20 15:13:38 2018

@author: akber
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

learning_rate = 0.01  # unused below; the optimizer is built with 1e-4 directly
num_steps = 480*10
batch_size = 125
test_batch_size = 100

display_step = 480
examples_to_show = 10
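
# With batch_size = 125, 480 steps consume 60,000 examples, so
# num_steps = 480*10 is roughly ten passes over the training data and
# display_step logs once per such pass. (Note: this loader holds out
# 5,000 of the 60,000 MNIST training images for validation, so
# mnist.train actually contains 55,000 examples.)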

# Network Parameters
num_input = 784  # MNIST data input (img shape: 28*28)
training_loss = []
error = []
perf = []
# lambda (beta) values to sweep for the DCT-loss weight
rate = np.array([10**-4, 10**-3, 10**-2, 10**-1, 10**-0])

X = tf.placeholder("float", [None, num_input])
Y = tf.placeholder(tf.float32, [None, 10])
beta = tf.placeholder(tf.float32, shape=())
def weight_variable(shape, name):
    # From the mnist tutorial
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)


def fc_layer(previous, input_size, output_size, name):
    W = weight_variable([input_size, output_size], name)
    b = bias_variable([output_size], name)
    return tf.matmul(previous, W) + b
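
# Shape note for the helper above: fc_layer(h, m, n, name) maps a
# [batch, m] tensor to [batch, n] via tf.matmul(h, W) + b, with
# W of shape [m, n] and b of shape [n].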

def autoencoder(x, y, b):
    # y (labels) is accepted but unused; b weights the DCT term in the loss.
    # Encoder: 784 -> 300 -> 60 -> 30 (linear bottleneck); decoder mirrors it.
    l1 = tf.nn.tanh(fc_layer(x, 28*28, 300, "auto"))
    l2 = tf.nn.tanh(fc_layer(l1, 300, 60, "auto"))
    l3 = fc_layer(l2, 60, 30, "auto")
    l4 = tf.nn.tanh(fc_layer(l3, 30, 60, "auto"))
    l5 = tf.nn.tanh(fc_layer(l4, 60, 300, "auto"))
    out = tf.nn.relu(fc_layer(l5, 300, 28*28, "auto"))
    # Pixel-space MSE plus b-weighted MSE between 2-D DCTs of input and output
    loss = tf.reduce_mean(tf.squared_difference(x, out)) \
           + b*tf.reduce_mean(tf.squared_difference(dct2(x), dct2(out)))
    mse = tf.reduce_mean(tf.squared_difference(x, out))
    return loss, mse, out, l3
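
# Note on the objective above: the training loss is
#   mean((x - out)^2) + b * mean((dct2(x) - dct2(out))^2),
# so feeding b = 0 recovers the plain pixel-space MSE that `mse` reports.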

def dct2(x):
    # x is a batch of flattened images, shape [batch, 784]. For each image,
    # apply a separable 2-D type-II DCT: a 1-D DCT along one axis (via a
    # transpose, since tf.spectral.dct transforms the last dimension),
    # then along the other, and flatten the result back to [784].
    def _image_dct(flat):
        img = tf.reshape(flat, [28, 28])
        d = tf.spectral.dct(tf.transpose(tf.spectral.dct(tf.transpose(img))))
        return tf.reshape(d, [-1])
    # tf.map_fn handles any batch size, so no batch_size/test_batch_size
    # flag is needed here.
    return tf.map_fn(_image_dct, x)
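
# Optional sanity-check sketch (assumes SciPy is available; call it
# manually if desired): tf.spectral.dct computes the non-normalized
# type-II DCT along the last axis, so dct2 should agree with a 2-D
# reference transform built from scipy.fftpack.dct.
def _check_dct2_against_scipy():
    from scipy.fftpack import dct as sp_dct
    img = np.random.rand(1, 28*28).astype(np.float32)
    ref = sp_dct(sp_dct(img.reshape(28, 28), type=2, axis=0), type=2, axis=1)
    with tf.Session() as s:
        got = s.run(dct2(tf.constant(img)))[0].reshape(28, 28)
    print('max abs difference vs scipy:', np.abs(ref - got).max())
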
loss, mse, output, latent = autoencoder(X, Y, beta)

# and we use the Adam Optimizer for training
optimizer = tf.train.AdamOptimizer(1e-4).minimize(loss)
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
#init2 = tf.initializers.variables(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="auto"))
with tf.Session() as sess:
    # Run the initializer
    sess.run(init)
    for t in rate:
        # Re-initialize so each lambda value trains from scratch
        sess.run(init)
        for i in range(1, num_steps+1):
            # Prepare Data
            # Get the next batch of MNIST data (only images are needed, not labels)
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            #batch_x, batch_y = mnist.train.next_batch(1)

            # Run optimization op (backprop) and cost op (to get loss value)
            _, l = sess.run([optimizer, loss],
                            feed_dict={X: batch_x, Y: batch_y, beta: t})
            # Display logs per step
            if i % display_step == 0 or i == 1:
                print('Step %i: Minibatch Loss: %f' % (i, l))
                #training_loss.append(l)

        # Evaluate on the test set: average reconstruction MSE over n
        # minibatches (mse depends only on X, so no other feeds are needed)
        n = 100
        error = []
        for i in range(n):
            batch_x, _ = mnist.test.next_batch(test_batch_size)
            #batch_x, _ = mnist.test.next_batch(1)
            l = sess.run(mse, feed_dict={X: batch_x})
            error.append(l)

        #plt.plot(error)
        #plt.show()
        perf.append(sum(error)/n)

plt.semilogx(rate, perf)
plt.xlabel(r'$\lambda$')
plt.ylabel('MSE')
plt.show()