Advertisement
Guest User

Untitled

a guest
Aug 18th, 2017
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.33 KB | None | 0 0
  1. class LSTM_Cell:
  2. def __init__(self, time_steps=None, input_size=None, cell_size=None, learning_rate=1e-3):
  3. assert cell_size is not None and input_size is not None
  4.  
  5. self.input_size = input_size
  6. self.cell_size = cell_size
  7.  
  8. self.inputs = tf.placeholder(tf.float32, shape=[None, None, input_size], name="lstm_inputs")
  9. self.targets = tf.placeholder(tf.float32, shape=[None, None, input_size], name="lstm_targets")
  10.  
  11. self.w_f = tf.Variable(tf.random_normal(shape=[self.input_size, self.cell_size]), name="w_f")
  12. self.u_f = tf.Variable(tf.random_normal(shape=[self.cell_size, self.cell_size]), name="u_f")
  13. self.b_f = tf.Variable(tf.constant(0., shape=[self.cell_size]), name="b_f")
  14.  
  15. self.w_i = tf.Variable(tf.random_normal(shape=[self.input_size, self.cell_size]), name="w_i")
  16. self.u_i = tf.Variable(tf.random_normal(shape=[self.cell_size, self.cell_size]), name="u_i")
  17. self.b_i = tf.Variable(tf.constant(0., shape=[self.cell_size]), name="b_i")
  18.  
  19. self.w_o = tf.Variable(tf.random_normal(shape=[self.input_size, self.cell_size]), name="w_o")
  20. self.u_o = tf.Variable(tf.random_normal(shape=[self.cell_size, self.cell_size]), name="u_o")
  21. self.b_o = tf.Variable(tf.constant(0., shape=[self.cell_size]), name="b_0")
  22.  
  23. self.w_c = tf.Variable(tf.random_normal(shape=[self.input_size, self.cell_size]), name="w_c")
  24. self.u_c = tf.Variable(tf.random_normal(shape=[self.cell_size, self.cell_size]), name="u_cs")
  25. self.b_c = tf.Variable(tf.constant(0., shape=[self.cell_size]), name="b_c")
  26.  
  27. self.learning_rate = learning_rate
  28.  
  29. self.outputs = None
  30. self.last_hidden_state = None
  31. self.last_cell_state = None
  32.  
  33. self.time_steps = time_steps
  34.  
  35. def call(self, state_tuple, x):
  36. """
  37. One iteration of the LSTM cell.
  38.  
  39. params:
  40. state_tuple: The previous hidden and cell state (of the shape [hidden or cell, batch_size])
  41. x: The batch input into the neural network (of the shape [batch_size, input_size])
  42. Example: [[3, 2, 1], [1, 2, 3]] -> Where each row represents a item in the batch
  43. returns:
  44. A new state tuple representing the new hidden and cel state
  45. """
  46. previous_hidden, previous_cell_state = tf.unstack(state_tuple)
  47.  
  48. f = tf.nn.sigmoid(tf.matmul(x, self.w_f) + tf.matmul(previous_hidden, self.u_f) + self.b_f)
  49. i = tf.nn.sigmoid(tf.matmul(x, self.w_i) + tf.matmul(previous_hidden, self.u_i) + self.b_i)
  50. o = tf.nn.sigmoid(tf.matmul(x, self.w_o) + tf.matmul(previous_hidden, self.u_o) + self.b_o)
  51.  
  52. cell_state_additions = tf.nn.tanh(tf.matmul(x, self.w_c) + tf.matmul(x, self.w_c) + self.b_c)
  53.  
  54. cell_state = tf.multiply(cell_state_additions, i) + tf.multiply(previous_cell_state, f)
  55. new_hidden = tf.nn.tanh(tf.multiply(cell_state, o))
  56.  
  57. return tf.stack([new_hidden, cell_state])
  58.  
  59. def dynamic_rnn(self, input_sequence=None, dynamic_output=False, initial_state_tuple=None):
  60. """
  61. Given a batch major input sequence it will run convert it to time major and run through
  62. the LSTM cell step by step.
  63.  
  64. params:
  65. input_sequence: A batch major input sequence (of the shape [batch_size, sequence_length, input_size])
  66. initial_state_tuple: A state tuple that should be used otherwise a zero one will be used.
  67.  
  68. returns:
  69. The hidden states and cell states for each batch item for each time step (returned batch_major and
  70. not time major).
  71. """
  72. input_sequence = tf.transpose(input_sequence, [1, 0, 2]) #tranpose to allow for batch processing
  73.  
  74. batch_items = tf.shape(input_sequence)[1]
  75.  
  76. if initial_state_tuple is None:
  77. initial_state_tuple = self.initial_state_tuple(batch_items)
  78.  
  79. state_tuples = tf.scan(self.call, input_sequence, initializer=initial_state_tuple)
  80.  
  81. hidden_states, cell_states = self.split_state_tuples(state_tuples)
  82. last_hidden_states, last_cell_states = tf.unstack(tf.gather(state_tuples, tf.shape(state_tuples)[1]))
  83.  
  84. return hidden_states, cell_states, tf.stack([last_hidden_states, last_cell_states])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement