Guest User

Untitled

a guest
Mar 18th, 2018
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.71 KB | None | 0 0
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
  3.  
# Dropout keep-probability, fed at run time (0.5 while training, 1.0 for eval).
keep_prob=tf.placeholder(tf.float32)
  5.  
  6. def evaluation(logits, labels):
  7. correct = tf.nn.in_top_k(logits, labels, 1)
  8. return tf.reduce_sum(tf.cast(correct, tf.int32))
  9.  
  10. mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
  11. xs=tf.placeholder(tf.float32,[None,784])
  12. ys=tf.placeholder(tf.float32,[None,10])
  13.  
  14. x_image = tf.reshape(xs,[-1,28,28,1])
  15.  
  16.  
  17. #===================================================
  18. def weight_v(shape):
  19. initial = tf.truncated_normal(shape,stddev=0.1)
  20. return tf.Variable(initial)
  21.  
  22. def bias_v(shape):
  23. initial = tf.constant(0.1,shape=shape)
  24. return tf.Variable(initial)
  25.  
  26. def conv2d(x,W):
  27. return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')
  28.  
  29. def max_pool_2x2(x):
  30. return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
  31. #----------------------L1--------------------------
  32. W_conv1 = weight_v([5,5,1,64])
  33. b_conv1 = bias_v([64])
  34.  
  35. hid_conv1 = tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
  36. hid_pool1 = max_pool_2x2(hid_conv1)
  37. #----------------------L2---------------------------
  38. W_conv2 = weight_v([5,5,64,128])
  39. b_conv2 = bias_v([128])
  40.  
  41. hid_conv2 = tf.nn.relu(conv2d(hid_pool1,W_conv2)+b_conv2)
  42. hid_pool2 = max_pool_2x2(hid_conv2)
  43. #---------------------------------------------------
  44. W_fc1 = weight_v([7*7*128,1024])
  45. b_fc1 = bias_v([1024])
  46.  
  47. hid_pool2flat = tf.reshape(hid_pool2,[-1,7*7*128])
  48. hid_fc1 = tf.nn.relu(tf.matmul(hid_pool2flat,W_fc1)+b_fc1)
  49. hid_fc1_dropout = tf.nn.dropout(hid_fc1,keep_prob=1.0)
  50.  
  51. #---------------------------------------------------
  52. W_fc2 = weight_v([1024,10])
  53. b_fc2 = bias_v([10])
  54. prediction = tf.nn.softmax(tf.matmul(hid_fc1_dropout,W_fc2)+b_fc2)
  55. #=================================================================
  56. cross_entropy=tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1]))
  57. #loss
  58. train=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  59. #saver = tf.train.Saver()
  60. #===================================================
  61. def accuracy(v_xs,v_ys):
  62. global prediction
  63. y_pre=sess.run(prediction,feed_dict={xs:v_xs,keep_prob:1.0})
  64. judge=tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
  65. acc=tf.reduce_mean(tf.cast(judge,tf.float32))
  66. result=sess.run(acc,feed_dict={xs:v_xs,ys:v_ys,keep_prob:1.0})
  67. return result
  68. #===================================================
  69.  
  70.  
  71. with tf.Session() as sess:
  72. sess.run(tf.global_variables_initializer())
  73. for i in range(0,1000):
  74. batch_xs,batch_ys=mnist.train.next_batch(50)
  75. sess.run(train,feed_dict={xs:batch_xs,ys:batch_ys,keep_prob:0.5})
  76. if i%20==0:
  77. print(i,accuracy(mnist.test.images[:1000],mnist.test.labels[:1000]))
  78. # save_path = saver.save(sess,"/home/ky/test/model/cnn.ckpt")
Add Comment
Please, Sign In to add comment