Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2016
1,019
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.19 KB | None | 0 0
  1. import tensorflow as tf
  2. from tensorflow.models.rnn import rnn, rnn_cell
  3. from tensorflow.examples.tutorials.mnist import input_data
  4.  
  5. # Get MNIST data
  6. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  7.  
  8. # Parameters
  9. batch_size      = 100
  10. input_size      = 28 # ncols of image
  11. sequence_length = 28 # nrows of image
  12. n_classes       = 10
  13.  
  14. X = tf.placeholder("float", [None, sequence_length, input_size])
  15. Y = tf.placeholder("float", [None, n_classes])
  16.  
  17. # LSTM Cell
  18. lstm = rnn_cell.LSTMCell(num_units=200,
  19.                          forget_bias=1.0,
  20.                          initializer=tf.random_normal)
  21.  
  22. # Initial state
  23. istate = lstm.zero_state(batch_size, "float")
  24.  
  25. # Get lstm cell output
  26. output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)
  27.  
  28. # Output at last time point T
  29. output_at_T = output[:, 27, :]
  30.  
  31. network_output = output_at_T
  32.  
  33. # Launch the graph
  34. with tf.Session() as sess:
  35.     sess.run(tf.initialize_all_variables())
  36.     X_batch, Y_batch = mnist.train.next_batch(batch_size)
  37.     X_batch = X_batch.reshape((batch_size, sequence_length, input_size))
  38.     network_output_ = sess.run(network_output, feed_dict={X: X_batch, Y: Y_batch})
  39.     print network_output_
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement