import numpy as np

from keras import backend as K
from keras import initializations, activations
from keras.engine import InputSpec
from keras.layers.recurrent import Recurrent


# Copied from the Keras GitHub source, with lots of code that is unnecessary
# here (for me) removed.
# This assumes a 2D convolution was run by hand before this layer
# (see the usage sketch at the bottom of this file).
# Please note that this layer has no trainable variables of its own.
# TODO: incorporate the 2D convolution into this layer.
class CGRU(Recurrent):
    def __init__(self,
                 init='glorot_uniform', inner_init='orthogonal',
                 activation='tanh', inner_activation='hard_sigmoid', **kwargs):
        self.init = initializations.get(init)
        self.inner_init = initializations.get(inner_init)
        self.activation = activations.get(activation)
        self.inner_activation = activations.get(inner_activation)
        # The regularizers and the dropout of the stock GRU have been removed.
        super(CGRU, self).__init__(**kwargs)
        # TODO: it is not very elegant to simply overwrite this. Maybe there is a better way?
        # Overwriting seems necessary in order to accept 5 input dimensions:
        # (samples, timesteps, features, x, y)
        self.input_spec = [InputSpec(ndim=5)]

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        # input_shape is (samples, timesteps, features, x, y), so the spatial
        # dimensions are input_shape[3] and input_shape[4].
        self.input_dim = (3, input_shape[3], input_shape[4])
        # Moved here from the constructor. The layer does not change the spatial
        # dimensions, but it does change the number of features: the 3 input
        # features feed the gates z and r and the candidate hh; there is a
        # single output feature.
        self.output_dim = (1, input_shape[3], input_shape[4])
        if self.stateful:
            self.reset_states()
        else:
            # initial state: an all-zero tensor of shape (samples, x, y)
            self.states = [None]

    def reset_states(self):
        # TODO: the state must become 2D. Am I doing this right?
        # Assuming that the first dimension is the batch size, this is now
        # hardcoded for 2D images and the Theano ('th') layout.
        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_spec[0].shape
        if not input_shape[0]:
            raise Exception('If a RNN is stateful, a complete '
                            'input_shape must be provided (including batch size).')
        if hasattr(self, 'states'):
            K.set_value(self.states[0],
                        np.zeros((input_shape[0], input_shape[3], input_shape[4])))
        else:
            self.states = [K.zeros((input_shape[0], input_shape[3], input_shape[4]))]
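
    # Stateful use requires a fixed batch size, so the layer would be built
    # with a complete batch_input_shape. An illustrative (hypothetical) call:
    #   model.add(CGRU(stateful=True, batch_input_shape=(8, 10, 3, 32, 32)))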

    def preprocess_input(self, x):
        # The stock GRU distinguishes between the 'cpu' and 'gpu' consume_less
        # modes here; for starters this just uses the 'gpu' code path, which
        # returns the input unchanged.
        return x

    def step(self, x, states):
        h_tm1 = states[0]  # previous memory
        # The original code computes z, r, and hh from learned weights.
        # Here the values produced by the preceding convolution are used
        # instead, so all of that code is dropped and the input is sliced.
        # TODO: the three features per coordinate are a hard requirement.
        # How can I enforce this?
        z = self.inner_activation(x[:, 0, :, :])
        r = self.inner_activation(x[:, 1, :, :])
        hh = self.activation(x[:, 2, :, :])
        # The standard GRU blend of old state and candidate. Note that r is
        # unused: in the stock GRU it gates the recurrent input to hh, but
        # here hh comes straight from the convolution.
        h = z * h_tm1 + (1 - z) * hh
        return h, [h]
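
    # A quick sanity check of the update rule above (illustrative numbers,
    # not from the original paste): with z = 0.9, h_tm1 = 4.0, hh = 8.0,
    #   z * h_tm1 + (1 - z) * hh = 0.9 * 4.0 + 0.1 * 8.0 = 4.4,
    # i.e. a high z mostly preserves the previous state.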

    def get_constants(self, x):
        constants = []
        # Copied from the stock GRU, where these are the dropout masks for the
        # recurrent and input weights (all ones here, i.e. no dropout).
        # step() never reads them, so they are kept only for compatibility.
        constants.append([K.cast_to_floatx(1.) for _ in range(3)])
        constants.append([K.cast_to_floatx(1.) for _ in range(3)])
        return constants

    def get_initial_states(self, x):
        # x: (samples, timesteps, input_dim) where input_dim = (3, x_dim, y_dim)
        initial_state = K.zeros_like(x)
        # sum out timesteps and features -> (samples, x_dim, y_dim)
        initial_state = K.sum(initial_state, axis=(1, 2))
        return [initial_state]

    def get_output_shape_for(self, input_shape):
        # TODO: this is hardcoded for the Theano ('th') layout and assumes
        # return_sequences=True.
        return (input_shape[0], input_shape[1], 1, input_shape[3], input_shape[4])

    def get_config(self):
        config = {'output_dim': self.output_dim,
                  'init': self.init.__name__,
                  'inner_init': self.inner_init.__name__,
                  'activation': self.activation.__name__,
                  'inner_activation': self.inner_activation.__name__}
        # The various regularizers and dropouts have been removed from the
        # config as well; surely they are not needed if not present?
        base_config = super(CGRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
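

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original layer code): a minimal model that
# wires a per-timestep 2D convolution in front of the CGRU, as assumed above.
# The shapes and hyperparameters are illustrative assumptions for Keras 1.x
# with the Theano ('th') dim ordering.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers.wrappers import TimeDistributed
    from keras.layers.convolutional import Convolution2D

    model = Sequential()
    # Input: (samples, timesteps=10, channels=1, x=32, y=32). The convolution
    # produces the 3 feature maps per pixel that step() slices into z, r, hh.
    model.add(TimeDistributed(Convolution2D(3, 3, 3, border_mode='same'),
                              input_shape=(10, 1, 32, 32)))
    model.add(CGRU(return_sequences=True))
    model.compile(optimizer='rmsprop', loss='mse')
    # Expected output shape, per get_output_shape_for: (None, 10, 1, 32, 32)
    print(model.output_shape)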