Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- class StandardCell:
def __init__(self, name: str, cell_size, cell_count):
    """Build the graph for one step of a stacked LSTM with named I/O tensors.

    Creates a ``rnn.MultiRNNCell`` of ``cell_count`` ``rnn.BasicLSTMCell``
    layers, each ``cell_size`` units wide. It is wired to float32
    placeholders and given explicitly named input/output tensors
    (``'std_input'``, ``'std_input_state'``, ``'std_output'``,
    ``'std_output_state'``). Everything here is built while
    ``tf.variable_scope(self.scope)`` is entered.

    Args:
        name: Name passed to the ``tf.VariableScope`` stored in ``self.scope``.
        cell_size: Units per LSTM layer; also the length of the 1-D input
            and output vectors.
        cell_count: Number of stacked LSTM layers.

    Attributes set:
        scope: ``tf.VariableScope`` constructed with ``reuse=False``.
        cell: The ``rnn.MultiRNNCell`` holding the layers.
        std_input: float32 placeholder of shape ``[cell_size]``.
        std_input_state: float32 placeholder of shape
            ``[cell_count, 2, 1, cell_size]``. For layer ``i``, ``[i][0]`` is
            fed as the LSTM ``c`` state and ``[i][1]`` as the ``h`` state.
        std_output: The cell output reshaped to ``[cell_size]``.
        std_output_state: The new per-layer ``(c, h)`` states, packed into
            the same ``[cell_count, 2, 1, cell_size]`` layout as
            ``std_input_state``.
    """
    self.scope = tf.VariableScope(False, name)
    with tf.variable_scope(self.scope):
        layers = [rnn.BasicLSTMCell(cell_size) for _ in range(cell_count)]
        self.cell = rnn.MultiRNNCell(layers)

        self.std_input = tf.placeholder(
            tf.float32, shape=[cell_size], name='std_input')
        self.std_input_state = tf.placeholder(
            tf.float32, shape=[cell_count, 2, 1, cell_size],
            name='std_input_state')

        # Give the 1-D input a leading dimension of 1 before feeding the cell.
        batched_input = tf.reshape(self.std_input, [1, cell_size])
        # Split the packed state placeholder into one (c, h) tuple per layer.
        # The two separate [i][...] lookups are kept as-is so the graph ops
        # created are the same as before.
        initial_state = [
            rnn.LSTMStateTuple(self.std_input_state[i][0],
                               self.std_input_state[i][1])
            for i in range(cell_count)
        ]
        cell_output, cell_state = self.cell(batched_input, initial_state)

        self.std_output = tf.reshape(cell_output, [cell_size],
                                     name='std_output')

        # Re-pack each layer's (c, h) into the std_input_state layout:
        # concatenating two [1, 1, 1, cell_size] pieces on axis 1 gives
        # [1, 2, 1, cell_size] per layer, and axis 0 then stacks the layers.
        packed_shape = [1, 1, 1, cell_size]
        per_layer = [
            tf.concat([tf.reshape(layer_state.c, packed_shape),
                       tf.reshape(layer_state.h, packed_shape)], 1)
            for layer_state in cell_state
        ]
        self.std_output_state = tf.concat(per_layer, axis=0,
                                          name='std_output_state')
- def __call__(self, input, state):
- with tf
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement