Advertisement
Guest User

Untitled

a guest
Jul 21st, 2017
44
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.52 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import time
  4.  
  5. def network(input_list):
  6. input,init_hidden_c,init_hidden_m = input_list
  7. cell = tf.nn.rnn_cell.BasicLSTMCell(256, state_is_tuple=True)
  8. init_hidden = tf.nn.rnn_cell.LSTMStateTuple(init_hidden_c, init_hidden_m)
  9. states, hidden_cm = tf.nn.dynamic_rnn(cell, input, dtype=tf.float32, initial_state=init_hidden)
  10. net = [v for v in tf.trainable_variables()]
  11. return states, hidden_cm, net
  12.  
  13. def action(x, h_c, h_m):
  14. t0 = time.time()
  15. outputs, output_h = sess.run([rnn_states[:,-1:,:], rnn_hidden_cm], feed_dict={
  16. rnn_input:x,
  17. rnn_init_hidden_c: h_c,
  18. rnn_init_hidden_m: h_m
  19. })
  20. dt = time.time() - t0
  21. return outputs, output_h, dt
  22.  
  23. rnn_input = tf.placeholder("float", [None, None, 512])
  24. rnn_init_hidden_c = tf.placeholder("float", [None,256])
  25. rnn_init_hidden_m = tf.placeholder("float", [None,256])
  26. rnn_input_list = [rnn_input, rnn_init_hidden_c, rnn_init_hidden_m]
  27. rnn_states, rnn_hidden_cm, rnn_net = network(rnn_input_list)
  28.  
  29. feed_input = np.random.uniform(low=-1.,high=1.,size=(1,1,512))
  30. feed_init_hidden_c = np.zeros(shape=(1,256))
  31. feed_init_hidden_m = np.zeros(shape=(1,256))
  32.  
  33. sess = tf.Session()
  34. sess.run(tf.global_variables_initializer())
  35. for i in range(10000):
  36. _, output_hidden_cm, deltat = action(feed_input, feed_init_hidden_c, feed_init_hidden_m)
  37. if i % 10 == 0:
  38. print 'Running time: ' + str(deltat)
  39. (feed_init_hidden_c, feed_init_hidden_m) = output_hidden_cm
  40. feed_input = np.random.uniform(low=-1.,high=1.,size=(1,1,512))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement