Advertisement
Guest User

Untitled

a guest
Dec 3rd, 2016
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.99 KB | None | 0 0
  1. import tensorflow as tf
  2. import scipy.misc as smp
  3.  
  4. from tensorflow.examples.tutorials.mnist import input_data
  5.  
  6. def weight_variable(shape):
  7. initial = tf.truncated_normal(shape, stddev = 0.1)
  8. return tf.Variable(initial)
  9.  
  10. def bias_variable(shape):
  11. initial = tf.constant(0.1, shape = shape)
  12. return tf.Variable(initial)
  13.  
  14. def conv2d(x, W):
  15. return tf.nn.conv2d(
  16. x,
  17. W,
  18. strides = [1, 1, 1, 1],
  19. padding = 'SAME'
  20. )
  21.  
  22. def max_pool_2x2(x):
  23. return tf.nn.max_pool(
  24. x,
  25. ksize = [1, 2, 2, 1],
  26. strides = [1, 2, 2, 1],
  27. padding = 'VALID'
  28. )
  29.  
  30. def network(loadNetwork):
  31. mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
  32.  
  33. x = tf.placeholder(tf.float32, shape = [None, 784])
  34. y_ = tf.placeholder(tf.float32, shape = [None, 10])
  35.  
  36. x_image = tf.reshape(x, [-1, 28, 28, 1])
  37.  
  38. netName = "mnist_network"
  39.  
  40. with tf.name_scope("KeepProbability"):
  41. fckeep_prob = tf.placeholder(tf.float32)
  42. convkeep_prob = tf.placeholder(tf.float32)
  43.  
  44. with tf.name_scope("Conv1"):
  45. W_conv1 = weight_variable([3, 3, 1, 8])
  46. b_conv1 = bias_variable([8])
  47. h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  48.  
  49. with tf.name_scope("Conv2"):
  50. W_conv2 = weight_variable([3, 3, 8, 16])
  51. b_conv2 = bias_variable([16])
  52. h_conv2 = tf.nn.relu(conv2d(h_conv1, W_conv2) + b_conv2)
  53.  
  54. with tf.name_scope("Pooling1"):
  55. h_pool1 = max_pool_2x2(h_conv2)
  56.  
  57. with tf.name_scope("Conv3"):
  58. W_conv3 = weight_variable([3, 3, 16, 24])
  59. b_conv3 = bias_variable([24])
  60. h_conv3 = tf.nn.relu(conv2d(h_pool1, W_conv3) + b_conv3)
  61.  
  62. with tf.name_scope("Pooling2"):
  63. h_pool2= max_pool_2x2(h_conv3)
  64.  
  65. with tf.name_scope("FC1"):
  66. h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 24])
  67. W_fc1 = weight_variable([7 * 7 * 24, 64])
  68. b_fc1 = bias_variable([64])
  69. h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
  70.  
  71. with tf.name_scope("DropOut2"):
  72. h_fc1_drop = tf.nn.dropout(h_fc1, fckeep_prob)
  73.  
  74. with tf.name_scope("FC2"):
  75. W_fc2 = weight_variable([64, 128])
  76. b_fc2 = bias_variable([128])
  77. h_fc2 = tf.nn.relu(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
  78.  
  79. with tf.name_scope("DropOut3"):
  80. h_fc2_drop = tf.nn.dropout(h_fc2, fckeep_prob)
  81.  
  82. with tf.name_scope("FC3"):
  83. W_fc3 = weight_variable([128, 10])
  84. b_fc3 = bias_variable([10])
  85.  
  86. y_conv = tf.nn.relu(tf.matmul(h_fc2_drop, W_fc3) + b_fc3)
  87.  
  88. with tf.name_scope("CrossEntropy"):
  89. cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_conv, y_))
  90.  
  91. with tf.name_scope("TrainingStep"):
  92. train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
  93.  
  94. with tf.name_scope("Accuracy"):
  95. correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
  96. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  97.  
  98. tf.scalar_summary('accuracy', accuracy)
  99. tf.scalar_summary('cross_entropy', cross_entropy)
  100.  
  101. sess = tf.InteractiveSession()
  102.  
  103. merged = tf.merge_all_summaries()
  104. tensorLogDir = "./tensorLogs"
  105. train_writer = tf.train.SummaryWriter(tensorLogDir, sess.graph)
  106.  
  107. sess.run(tf.initialize_all_variables())
  108.  
  109. numPasses = 10000
  110. printNum = numPasses / 10
  111.  
  112. saver = tf.train.Saver()
  113.  
  114. if not loadNetwork:
  115. for i in range(numPasses):
  116. batch = mnist.train.next_batch(100)
  117. if i % printNum == 0:
  118. train_accuracy = accuracy.eval(
  119. feed_dict={
  120. x: batch[0],
  121. y_: batch[1],
  122. fckeep_prob: 1.0,
  123. convkeep_prob: 1.0
  124. })
  125.  
  126. print("%s: step %d, training accuracy %g"%(netName, i, train_accuracy))
  127.  
  128. saver.save(sess, "./saved/{txt}".format(txt = netName), global_step = i)
  129.  
  130. _, pool1, pool2, summary = sess.run(
  131. [train_step,
  132. h_pool1,
  133. h_pool2,
  134. merged],
  135. feed_dict={
  136. x: batch[0],
  137. y_: batch[1],
  138. fckeep_prob: 0.5,
  139. convkeep_prob: 0.8
  140. })
  141.  
  142. train_writer.add_summary(summary, i)
  143.  
  144. saver.save(sess, "./saved/{txt}".format(txt = netName), global_step = numPasses)
  145.  
  146. for i in range(pool1.shape[0]):
  147. for j in range(pool1.shape[3]):
  148. img = smp.toimage(pool1[i,:,:,j])
  149. smp.imsave("./images14/img{i}_{j}.png".format(
  150. i = i,
  151. j = j
  152. ), img)
  153.  
  154. for i in range(pool2.shape[0]):
  155. for j in range(pool2.shape[3]):
  156. img = smp.toimage(pool2[i,:,:,j])
  157. smp.imsave("./images7/img{i}_{j}.png".format(
  158. i = i,
  159. j = j
  160. ), img)
  161.  
  162. i = 0
  163. for data in batch[0]:
  164. img = smp.toimage(data.reshape((28, 28)))
  165. smp.imsave("./inputs/img{i}.png".format(
  166. i = i,
  167. ), img)
  168. i += 1
  169.  
  170. print("test accuracy %g"%accuracy.eval(feed_dict={
  171. x: mnist.test.images,
  172. y_: mnist.test.labels,
  173. fckeep_prob: 1.0,
  174. convkeep_prob: 1.0
  175. }))
  176.  
  177. else:
  178. saver.restore(sess, "./saved/{txt}-{np}".format(
  179. txt = netName,
  180. np = numPasses
  181. ))
  182. print("Model restored.")
  183. print("Starting tests...")
  184. print("test accuracy %g"%accuracy.eval(feed_dict={
  185. x: mnist.test.images,
  186. y_: mnist.test.labels,
  187. fckeep_prob: 1.0,
  188. convkeep_prob: 1.0
  189. }))
  190.  
  191. print(netName, "finished!")
  192.  
  193. def trainNetwork():
  194. network(loadNetwork = False)
  195. print("Network successfully trained.")
  196.  
  197. def testNetwork():
  198. network(loadNetwork = True)
  199.  
  200. def main():
  201. trainNetwork()
  202. #testNetwork()
  203.  
  204. if __name__ == '__main__':
  205. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement