Guest User

Untitled

a guest
Jun 20th, 2018
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.92 KB | None | 0 0
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. from tensorflow.python.ops import rnn, rnn_cell
  4. mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)
  5.  
  6. hm_epochs = 3
  7. n_classes = 10
  8. batch_size = 128
  9. chunk_size = 28
  10. n_chunks = 28
  11. rnn_size = 128
  12.  
  13.  
  14. x = tf.placeholder('float', [None, n_chunks,chunk_size])
  15. y = tf.placeholder('float')
  16.  
  17. def recurrent_neural_network(x):
  18. layer = {'weights':tf.Variable(tf.random_normal([rnn_size,n_classes])),
  19. 'biases':tf.Variable(tf.random_normal([n_classes]))}
  20.  
  21. x = tf.transpose(x, [1,0,2])
  22. x = tf.reshape(x, [-1, chunk_size])
  23. x = tf.split(x, n_chunks, 0)
  24.  
  25. lstm_cell = rnn_cell.BasicLSTMCell(rnn_size,state_is_tuple=True)
  26. outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
  27.  
  28. output = tf.matmul(outputs[-1],layer['weights']) + layer['biases']
  29.  
  30. return output
  31.  
  32. def train_neural_network(x):
  33. prediction = recurrent_neural_network(x)
  34. cost = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction) )
  35. optimizer = tf.train.AdamOptimizer().minimize(cost)
  36.  
  37.  
  38. with tf.Session() as sess:
  39. sess.run(tf.initialize_all_variables())
  40.  
  41. for epoch in range(hm_epochs):
  42. epoch_loss = 0
  43. for _ in range(int(mnist.train.num_examples/batch_size)):
  44. epoch_x, epoch_y = mnist.train.next_batch(batch_size)
  45. epoch_x = epoch_x.reshape((batch_size,n_chunks,chunk_size))
  46.  
  47. _, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
  48. epoch_loss += c
  49.  
  50. print('Epoch', epoch, 'completed out of',hm_epochs,'loss:',epoch_loss)
  51.  
  52. correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
  53.  
  54. accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
  55. print('Accuracy:',accuracy.eval({x:mnist.test.images.reshape((-1, n_chunks, chunk_size)), y:mnist.test.labels}))
  56.  
  57. train_neural_network(x)
Add Comment
Please, Sign In to add comment