import tensorflow as tf


class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
    """An LSTM cell with convolutions instead of multiplications.

    Reference:
      Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning
      approach for precipitation nowcasting." Advances in Neural Information
      Processing Systems. 2015.
    """

    def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh,
                 normalize=True, peephole=True, data_format='channels_last', reuse=None):
        super(ConvLSTMCell, self).__init__(_reuse=reuse)
        self._kernel = kernel
        self._filters = filters
        self._forget_bias = forget_bias
        self._activation = activation
        self._normalize = normalize
        self._peephole = peephole
        if data_format == 'channels_last':
            self._size = tf.TensorShape(shape + [self._filters])
            self._feature_axis = self._size.ndims
            self._data_format = None
        elif data_format == 'channels_first':
            self._size = tf.TensorShape([self._filters] + shape)
            self._feature_axis = 0
            self._data_format = 'NC'
        else:
            raise ValueError('Unknown data_format')

    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)

    @property
    def output_size(self):
        return self._size

    def call(self, x, state):
        c, h = state

        # Compute all four gates with a single convolution over [input, hidden].
        x = tf.concat([x, h], axis=self._feature_axis)
        n = x.shape[-1].value
        m = 4 * self._filters if self._filters > 1 else 4
        W = tf.get_variable('kernel', self._kernel + [n, m])
        y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
        if not self._normalize:
            y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
        j, i, f, o = tf.split(y, 4, axis=self._feature_axis)

        # Peephole connections let the input and forget gates see the cell state.
        if self._peephole:
            i += tf.get_variable('W_ci', c.shape[1:]) * c
            f += tf.get_variable('W_cf', c.shape[1:]) * c

        if self._normalize:
            j = tf.contrib.layers.layer_norm(j)
            i = tf.contrib.layers.layer_norm(i)
            f = tf.contrib.layers.layer_norm(f)

        f = tf.sigmoid(f + self._forget_bias)
        i = tf.sigmoid(i)
        c = c * f + i * self._activation(j)

        # The output gate peeks at the updated cell state.
        if self._peephole:
            o += tf.get_variable('W_co', c.shape[1:]) * c

        if self._normalize:
            o = tf.contrib.layers.layer_norm(o)
            c = tf.contrib.layers.layer_norm(c)

        o = tf.sigmoid(o)
        h = o * self._activation(c)

        # TODO
        #tf.summary.histogram('forget_gate', f)
        #tf.summary.histogram('input_gate', i)
        #tf.summary.histogram('output_gate', o)
        #tf.summary.histogram('cell_state', c)

        state = tf.nn.rnn_cell.LSTMStateTuple(c, h)

        return h, state

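# Usage sketch for ConvLSTMCell (illustrative sizes, TensorFlow 1.x graph mode):
# treat each video frame as one timestep and let tf.nn.dynamic_rnn unroll the
# cell over the time axis. All sizes below are made-up examples, not values
# required by the cell.
#
#   batch_size, timesteps = 8, 16
#   shape = [64, 64]            # height, width of each frame
#   kernel = [3, 3]             # convolution kernel size
#   channels, filters = 3, 12
#
#   inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
#   cell = ConvLSTMCell(shape, filters, kernel)
#   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
#   # outputs has shape [batch_size, timesteps] + shape + [filters]

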
class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
    """A GRU cell with convolutions instead of multiplications."""

    def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True,
                 data_format='channels_last', reuse=None):
        super(ConvGRUCell, self).__init__(_reuse=reuse)
        self._filters = filters
        self._kernel = kernel
        self._activation = activation
        self._normalize = normalize
        if data_format == 'channels_last':
            self._size = tf.TensorShape(shape + [self._filters])
            self._feature_axis = self._size.ndims
            self._data_format = None
        elif data_format == 'channels_first':
            self._size = tf.TensorShape([self._filters] + shape)
            self._feature_axis = 0
            self._data_format = 'NC'
        else:
            raise ValueError('Unknown data_format')

    @property
    def state_size(self):
        return self._size

    @property
    def output_size(self):
        return self._size

    def call(self, x, h):
        channels = x.shape[self._feature_axis].value

        # Reset and update gates share one convolution over [input, hidden].
        with tf.variable_scope('gates'):
            inputs = tf.concat([x, h], axis=self._feature_axis)
            n = channels + self._filters
            m = 2 * self._filters if self._filters > 1 else 2
            W = tf.get_variable('kernel', self._kernel + [n, m])
            y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
            if self._normalize:
                r, u = tf.split(y, 2, axis=self._feature_axis)
                r = tf.contrib.layers.layer_norm(r)
                u = tf.contrib.layers.layer_norm(u)
            else:
                y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
                r, u = tf.split(y, 2, axis=self._feature_axis)
            r, u = tf.sigmoid(r), tf.sigmoid(u)

        # TODO
        #tf.summary.histogram('reset_gate', r)
        #tf.summary.histogram('update_gate', u)

        # Candidate activation uses the reset-gated hidden state.
        with tf.variable_scope('candidate'):
            inputs = tf.concat([x, r * h], axis=self._feature_axis)
            n = channels + self._filters
            m = self._filters
            W = tf.get_variable('kernel', self._kernel + [n, m])
            y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
            if self._normalize:
                y = tf.contrib.layers.layer_norm(y)
            else:
                y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
            h = u * h + (1 - u) * self._activation(y)

        return h, h
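

if __name__ == '__main__':
    # Minimal smoke test (illustrative sizes, TensorFlow 1.x graph mode): the
    # ConvGRUCell plugs into tf.nn.dynamic_rnn exactly like the LSTM variant,
    # but its state is a single tensor rather than an LSTMStateTuple.
    batch_size, timesteps = 2, 4
    shape = [16, 16]      # spatial dimensions of each frame
    kernel = [3, 3]
    channels, filters = 3, 8

    inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
    cell = ConvGRUCell(shape, filters, kernel)
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
    print(outputs.shape)  # expected: (2, 4, 16, 16, 8)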