Advertisement
Guest User

Untitled

a guest
Feb 26th, 2017
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.38 KB | None | 0 0
  1. class StandardCell:
  def __init__(self, name: str, cell_size: int, cell_count: int):
      """Build a one-step, batch-size-1 stacked-LSTM graph with named endpoints.

      A stack of ``cell_count`` ``BasicLSTMCell`` layers, each with
      ``cell_size`` units, is wired to placeholders so that a single step can
      be fed and fetched by tensor name. The step is built like this:

      * ``std_input`` ('std_input') holds a ``[cell_size]`` float32 vector.
        The input width is therefore tied to ``cell_size``.
      * ``std_input_state`` ('std_input_state') holds the packed state of all
        layers, with shape ``[cell_count, 2, 1, cell_size]``. Index
        ``[i][0]`` is layer ``i``'s ``c`` and ``[i][1]`` is its ``h``. The
        size-1 axis is the batch dimension the cell is called with.
      * ``std_output`` ('std_output') is the top layer's output, reshaped to
        ``[cell_size]``.
      * ``std_output_state`` ('std_output_state') holds the new state, packed
        in the same ``[cell_count, 2, 1, cell_size]`` layout as the input
        state. It can therefore be fed straight back in on the next step.

      Args:
          name: name of the variable scope all of this is built under.
          cell_size: units per LSTM layer; also the width of ``std_input``.
          cell_count: number of stacked LSTM layers.
      """
      # The first positional argument of tf.VariableScope is `reuse`, so this
      # scope starts out non-reusing. The scope object is kept, presumably so
      # __call__ can re-enter it and share weights (TODO confirm).
      self.scope = tf.VariableScope(False, name)
      with tf.variable_scope(self.scope):
          self.cell = rnn.MultiRNNCell(
              [rnn.BasicLSTMCell(cell_size) for _ in range(cell_count)])

          self.std_input = tf.placeholder(
              tf.float32, [cell_size], 'std_input')
          self.std_input_state = tf.placeholder(
              tf.float32, [cell_count, 2, 1, cell_size], 'std_input_state')

          # The cell expects a batch axis, so run a batch of exactly one.
          step_input = tf.reshape(self.std_input, [1, cell_size])
          # Index by layer number: a graph-mode Tensor is not Python-iterable.
          step_state = [
              rnn.LSTMStateTuple(self.std_input_state[i][0],
                                 self.std_input_state[i][1])
              for i in range(cell_count)]
          output, output_state = self.cell(step_input, step_state)

          self.std_output = tf.reshape(output, [cell_size], 'std_output')
          # Re-pack each layer's (c, h) pair as [1, 2, 1, cell_size], then
          # stack the layers along axis 0 to mirror std_input_state's layout.
          self.std_output_state = tf.concat(
              [tf.concat([tf.reshape(layer.c, [1, 1, 1, cell_size]),
                          tf.reshape(layer.h, [1, 1, 1, cell_size])], 1)
               for layer in output_state],
              0, 'std_output_state')
  26.  
  27. def __call__(self, input, state):
  28. with tf
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement