Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Run an untrained LSTM over MNIST digits, feeding one image row per time step,
# and print the LSTM output at the final time step for one training batch.
#
# Written against the TensorFlow 1.x graph API (placeholders + Session).
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Get MNIST data from MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Parameters
batch_size = 100
input_size = 28  # ncols of image: features fed per time step
sequence_length = 28  # nrows of image: number of time steps
n_classes = 10
num_units = 200  # LSTM hidden size

# The leading dimension is None, so the graph accepts any batch size.
X = tf.placeholder(tf.float32, [None, sequence_length, input_size])
# Labels. The forward pass below does not consume them; presumably reserved
# for a loss/accuracy op added later.
Y = tf.placeholder(tf.float32, [None, n_classes])

# LSTM cell. `initializer` must be an initializer object rather than a raw op
# such as tf.random_normal, which rejects the extra keyword arguments TF 1.x
# variable creation passes. Defaults (mean 0, stddev 1) match tf.random_normal.
lstm = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                               forget_bias=1.0,
                               initializer=tf.random_normal_initializer())

# Passing dtype instead of an explicit zero_state(batch_size, ...) still starts
# from a zero state, but does not tie the graph to exactly batch_size rows.
output, states = tf.nn.dynamic_rnn(lstm, X, dtype=tf.float32)

# output is (batch, sequence_length, num_units); keep the last time step.
output_at_T = output[:, sequence_length - 1, :]
network_output = output_at_T

# Launch the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    X_batch, Y_batch = mnist.train.next_batch(batch_size)
    # next_batch yields flattened images; restore the row-by-row sequence shape.
    X_batch = X_batch.reshape((-1, sequence_length, input_size))
    network_output_ = sess.run(network_output, feed_dict={X: X_batch, Y: Y_batch})
    print(network_output_)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement