nordlaender
Keras CGRU
Nov 2nd, 2016
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, merge
from keras.layers import Convolution2D, MaxPooling2D, Reshape, Convolution3D, MaxPooling3D
from keras.layers import BatchNormalization, AveragePooling3D, UpSampling3D, Cropping3D, Merge
from keras.layers import SpatialDropout3D
from keras.layers.noise import GaussianNoise
from keras.regularizers import l1, l2, l1l2
from keras.utils import np_utils
from keras import backend as K
from keras.optimizers import RMSprop, Adam
from keras.layers.recurrent import Recurrent
from keras import initializations, activations, regularizers
from keras.engine import InputSpec

# copied from the Keras GitHub source; removed lots of code that is unnecessary for me

# assuming a 2D Convolution was run by hand before this layer
# (see the usage sketch at the bottom of this file).
# please note that this layer has no trainable variables of its own.
# TODO: incorporate the 2D Convolution into this layer

class CGRU(Recurrent):
    def __init__(self,
                 init='glorot_uniform', inner_init='orthogonal',
                 activation='tanh', inner_activation='hard_sigmoid', **kwargs):

        self.init = initializations.get(init)
        self.inner_init = initializations.get(inner_init)
        self.activation = activations.get(activation)
        self.inner_activation = activations.get(inner_activation)

        # removing the regularizers and the dropout

        super(CGRU, self).__init__(**kwargs)

        # this seems necessary in order to accept 5 input dimensions
        # (samples, timesteps, features, x, y)
        self.input_spec = [InputSpec(ndim=5)]

    def build(self, input_shape):
        # input_shape is (samples, timesteps, features, x, y)
        self.input_spec = [InputSpec(shape=input_shape)]
        self.input_dim = (input_shape[2], input_shape[3], input_shape[4])

        # moved here from the constructor. Curiously it seems like batch_size has been removed from here.
        self.output_dim = [1, input_shape[3], input_shape[4]]

        if self.stateful:
            self.reset_states()
        else:
            # initial states: all-zero tensor of shape (output_dim)
            self.states = [None]

    def reset_states(self):
        # TODO: the state must become 2D. am I doing this right?
        # TODO: assuming that the first dimension is batch_size, I'm now hardcoding for 2D images and th layout

        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_spec[0].shape
        if not input_shape[0]:
            raise Exception('If a RNN is stateful, a complete ' +
                            'input_shape must be provided (including batch size).')
        # the hidden state is one (x, y) map per sample: (samples, x, y)
        if hasattr(self, 'states'):
            K.set_value(self.states[0],
                        np.zeros((input_shape[0], input_shape[3], input_shape[4])))
        else:
            self.states = [K.zeros((input_shape[0], input_shape[3], input_shape[4]))]

    def preprocess_input(self, x):
        # here was a distinction between cpu and gpu. for starters I'll just use the cpu code
        return x

    def step(self, x, states):
        h_tm1 = states[0]  # previous memory

        # TODO: I've got no idea where the states are set. Maybe in the superclass?
        # TODO: debug to see how many entries there really are in the states variable
        # B_U = states[1]  # dropout matrices for recurrent units
        # B_W = states[2]

        # TODO: now I need to use the features from the Convolution.
        # Since I'll hardcode for th layout, the x will (hopefully) look like:
        # [batch, features, x_dim, y_dim]
        # note: slicing is possible, just always use ':' for the entire first dimension (batch)
        """
        if self.consume_less == 'gpu':

            matrix_x = K.dot(x * B_W[0], self.W) + self.b
            matrix_inner = K.dot(h_tm1 * B_U[0], self.U[:, :2 * self.output_dim])

            x_z = matrix_x[:, :self.output_dim]
            x_r = matrix_x[:, self.output_dim: 2 * self.output_dim]
            inner_z = matrix_inner[:, :self.output_dim]
            inner_r = matrix_inner[:, self.output_dim: 2 * self.output_dim]

            z = self.inner_activation(x_z + inner_z)
            r = self.inner_activation(x_r + inner_r)

            x_h = matrix_x[:, 2 * self.output_dim:]
            inner_h = K.dot(r * h_tm1 * B_U[0], self.U[:, 2 * self.output_dim:])
            hh = self.activation(x_h + inner_h)
        else:
            if self.consume_less == 'cpu':
                x_z = x[:, :self.output_dim]
                x_r = x[:, self.output_dim: 2 * self.output_dim]
                x_h = x[:, 2 * self.output_dim:]
            elif self.consume_less == 'mem':
                x_z = K.dot(x * B_W[0], self.W_z) + self.b_z
                x_r = K.dot(x * B_W[1], self.W_r) + self.b_r
                x_h = K.dot(x * B_W[2], self.W_h) + self.b_h
            else:
                raise Exception('Unknown `consume_less` mode.')
            z = self.inner_activation(x_z + K.dot(h_tm1 * B_U[0], self.U_z))
            r = self.inner_activation(x_r + K.dot(h_tm1 * B_U[1], self.U_r))

            hh = self.activation(x_h + K.dot(r * h_tm1 * B_U[2], self.U_h))
        """
        # all the code above produces z, r, and hh.
        # I would like to use the values produced by the convolution instead:
        # just drop all of the code above and slice the input

        # TODO: add the activations here
        z = self.inner_activation(x[:, 0, :, :])
        r = self.inner_activation(x[:, 1, :, :])
        hh = self.activation(x[:, 2, :, :])

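        # standard GRU state update: the new state is a per-pixel convex
        # combination of the previous state and the candidate hh, gated by z.
        # note (added comment): in this reduced step the reset gate r is
        # computed but never used, because hh comes straight from the
        # convolution rather than from r * h_tm1 as in the original GRU.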
        h = z * h_tm1 + (1 - z) * hh
        return h, [h]

    def get_constants(self, x):
        constants = []
        # dropping all of this. There is no dropout or anything else in this layer.
        # TODO: do I need this method at all? It overrides something from super.
        # might be better to stick with the inherited method if I don't do anything here.
        return constants

    def get_initial_states(self, x):
        initial_state = K.zeros_like(x)                     # (samples, timesteps, features, x_dim, y_dim)
        initial_state = K.sum(initial_state, axis=(1, 2))   # (samples, x_dim, y_dim)
        # Recurrent expects a list of initial states, one per entry in self.states
        return [initial_state]

    def get_output_shape_for(self, input_shape):
        # TODO: this is hardcoding for th layout
        # input is (samples, timesteps, features, x, y)
        return (input_shape[0], 1, input_shape[3], input_shape[4])

    def get_config(self):
        config = {'output_dim': self.output_dim,
                  'init': self.init.__name__,
                  'inner_init': self.inner_init.__name__,
                  'activation': self.activation.__name__,
                  'inner_activation': self.inner_activation.__name__}

        # removed the various regularizers and dropouts.
        # surely this isn't needed if not present?
        base_config = super(CGRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
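
# ---------------------------------------------------------------------------
# Usage sketch (added, not part of the original GRU source): the comments at
# the top say a 2D convolution is run "by hand" before this layer, producing
# the three feature maps that step() slices into z, r and hh. A minimal way to
# wire that up with the Keras 1.x functional API might look like the code
# below. The shapes (10 timesteps, 1 input channel, 32x32 images) and the use
# of TimeDistributed(Convolution2D(...)) for the per-timestep convolution are
# illustrative assumptions, not something prescribed by this paste.
if __name__ == '__main__':
    from keras.layers.wrappers import TimeDistributed

    frames = Input(shape=(10, 1, 32, 32))            # (samples, timesteps, features, x, y)
    # per-timestep convolution producing the 3 gate maps consumed by CGRU.step
    gates = TimeDistributed(Convolution2D(3, 3, 3,
                                          border_mode='same',
                                          dim_ordering='th'))(frames)
    state = CGRU()(gates)                            # declared output: (samples, 1, x, y)
    model = Model(input=frames, output=state)
    model.compile(optimizer=Adam(), loss='mse')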