Advertisement
Guest User

Untitled

a guest
May 24th, 2017
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.01 KB | None | 0 0
  1. import numpy as np
  2. from random import shuffle
  3. import tensorflow as tf
  4.  
  5. def generate_sequence_data(batch_size, sequence_length):
  6. # sequence data is 60 inputs long and each input is of length 4
  7. # batch is length 10
  8. train_input = []
  9. for i in range(batch_size):
  10. batch = []
  11. for j in range(sequence_length):
  12. inp = [1, 0, 0 , 0]
  13. batch.append(inp)
  14. train_input.append(batch)
  15. return train_input
  16.  
  17.  
  18. def run():
  19. # this function generates input and creates a very basic graph
  20. # the graph consists of an input placeholder and the LSTM cell
  21. #
  22. batch_size = 1
  23. num_hidden = 1
  24. sequence_length = 10
  25. input_length = 4
  26. train_input = generate_sequence_data(batch_size, sequence_length)
  27.  
  28. # PHASE 1 - build the computation graph
  29. # create a placeholder for the inputs
  30. input_placeholder = tf.placeholder(tf.float32, shape=(batch_size, sequence_length, input_length))
  31. # create the RNN cell
  32. cell = tf.contrib.rnn.BasicLSTMCell(num_hidden)
  33. # output is the outputs of the cell for each input
  34. # state is the final state of the cell
  35. output, state = tf.nn.dynamic_rnn(cell, input_placeholder, dtype=tf.float32)
  36. # PHASE 2 - run the graph using DNA sequences that we generated above
  37. init_op = tf.global_variables_initializer()
  38. sess = tf.Session()
  39. sess.run(init_op)
  40. inp = train_input[0:batch_size]
  41. # returns the values for each hidden unit at each step of computation
  42. # this returns an array of length batch_size
  43. # each element of that array is an array of length sequence_length
  44. # each element of that array is an array of length num_hidden
  45.  
  46. outputs = sess.run(output, {input_placeholder: inp})
  47. print("output values: ")
  48. print(outputs)
  49. # returns a LSTMStateTuple where c = cell state and h = output value
  50. states = sess.run(state,{input_placeholder: inp})
  51. print("internal states: ")
  52. print(states)
  53.  
  54.  
  55. if __name__ == "__main__":
  56. # arguments are
  57. # num_hidden, normalization method, max or sum
  58. run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement