Advertisement
Guest User

Untitled

a guest
Jun 16th, 2019
50
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.64 KB | None | 0 0
  1. ### ATTEMPT #1:
  2. def pi_ini(shape):
  3. init=np.random.uniform(low=0.0, high=1.0, size=shape)
  4. return init
  5.  
  6. def spike_ini(shape, pi, dtype=np.float32):
  7. init = np.random.binomial(n=1,p=pi,size=shape).astype(dtype=dtype)
  8. return init
  9.  
  10. class PenultimateLayer(Layer):
  11. def __init__(self, units, pi_initializer, spike_initializer,
  12. activation=None,
  13. use_bias=True,
  14. kernel_initializer='he_normal',
  15. bias_initializer='he_normal',
  16. kernel_regularizer=None,
  17. bias_regularizer=None,
  18. activity_regularizer=None,
  19. kernel_constraint=None,
  20. bias_constraint=None,
  21. **kwargs):
  22. if 'input_shape' not in kwargs and 'input_dim' in kwargs:
  23. kwargs['input_shape'] = (kwargs.pop('input_dim'),)
  24. super(PenultimateLayer, self).__init__(**kwargs)
  25.  
  26.  
  27. self.units = units
  28. self.activation = activations.get(activation)
  29. self.use_bias = use_bias
  30.  
  31. self.pi_initializer=pi_initializer
  32. self.spike_initializer=spike_initializer
  33. self.kernel_initializer = initializers.get(kernel_initializer)
  34. self.bias_initializer = initializers.get(bias_initializer)
  35. self.kernel_regularizer = regularizers.get(kernel_regularizer)
  36. self.bias_regularizer = regularizers.get(bias_regularizer)
  37. self.activity_regularizer = regularizers.get(activity_regularizer)
  38. self.kernel_constraint = constraints.get(kernel_constraint)
  39. self.bias_constraint = constraints.get(bias_constraint)
  40. self.input_spec = InputSpec(min_ndim=2)
  41. self.supports_masking = True
  42.  
  43. def build(self, input_shape):
  44. assert len(input_shape) >= 2
  45. input_dim = input_shape[-1]
  46.  
  47. self.kernel = self.add_weight(shape=(input_dim, self.units),
  48. initializer=self.kernel_initializer,
  49. name='kernel',
  50. regularizer=self.kernel_regularizer,
  51. constraint=self.kernel_constraint)
  52.  
  53.  
  54. self.pi=self.add_weight(shape=(input_dim, self.units),
  55. initializer=self.pi_initializer,
  56. name='pi',
  57. regularizer=None,
  58. constraint=None
  59. )
  60.  
  61. self.spike=self.add_weight(shape=(input_dim, self.units),
  62. initializer=self.spike_initializer,
  63. name='spike',
  64. regularizer=None,
  65. constraint=None)
  66.  
  67. if self.use_bias:
  68. self.bias = self.add_weight(shape=(self.units,),
  69. initializer=self.bias_initializer,
  70. name='bias',
  71. regularizer=self.bias_regularizer,
  72. constraint=self.bias_constraint)
  73. else:
  74. self.bias = None
  75.  
  76. self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
  77. self.built = True
  78.  
  79. def call(self, inputs):
  80. output = K.dot(inputs, (self.spike*self.kernel))
  81. if self.use_bias:
  82. output = K.bias_add(output, self.bias, data_format='channels_last')
  83. if self.activation is not None:
  84. output = self.activation(output)
  85. return output
  86.  
  87. def compute_output_shape(self, input_shape):
  88. assert input_shape and len(input_shape) >= 2
  89. assert input_shape[-1]
  90. output_shape = list(input_shape)
  91. output_shape[-1] = self.units
  92. return tuple(output_shape)
  93.  
  94.  
  95. ### ATTEMPT #2:
  96. class PenultimateLayer(Layer):
  97. def __init__(self, units, pi_initializer,
  98. activation=None,
  99. use_bias=True,
  100. kernel_initializer='he_normal',
  101. bias_initializer='he_normal',
  102. kernel_regularizer=None,
  103. bias_regularizer=None,
  104. activity_regularizer=None,
  105. kernel_constraint=None,
  106. bias_constraint=None,
  107. **kwargs):
  108. if 'input_shape' not in kwargs and 'input_dim' in kwargs:
  109. kwargs['input_shape'] = (kwargs.pop('input_dim'),)
  110. super(PenultimateLayer, self).__init__(**kwargs)
  111.  
  112.  
  113. self.units = units
  114. self.activation = activations.get(activation)
  115. self.use_bias = use_bias
  116.  
  117. # def spikeVar_initializer(shape, pi, name=None):
  118. # sp_vars=np.random.binomial(n=1, p=pi, size=shape)
  119. # return K.variable(sp_vars.astype(dtype=np.float32), name=name)
  120.  
  121. self.pi_initializer=spike_initializer
  122. self.kernel_initializer = initializers.get(kernel_initializer)
  123. self.bias_initializer = initializers.get(bias_initializer)
  124. self.kernel_regularizer = regularizers.get(kernel_regularizer)
  125. self.bias_regularizer = regularizers.get(bias_regularizer)
  126. self.activity_regularizer = regularizers.get(activity_regularizer)
  127. self.kernel_constraint = constraints.get(kernel_constraint)
  128. self.bias_constraint = constraints.get(bias_constraint)
  129. self.input_spec = InputSpec(min_ndim=2)
  130. self.supports_masking = True
  131.  
  132. def build(self, input_shape):
  133. assert len(input_shape) >= 2
  134. input_dim = input_shape[-1]
  135.  
  136. self.kernel = self.add_weight(shape=(input_dim, self.units),
  137. initializer=self.kernel_initializer,
  138. name='kernel',
  139. regularizer=self.kernel_regularizer,
  140. constraint=self.kernel_constraint)
  141.  
  142. self.pi=self.add_weight(shape=(input_dim, self.units),
  143. initializer=self.pi_initializer,
  144. name='pi',
  145. regularizer=None,
  146. constraint=None
  147. )
  148.  
  149. if self.use_bias:
  150. self.bias = self.add_weight(shape=(self.units,),
  151. initializer=self.bias_initializer,
  152. name='bias',
  153. regularizer=self.bias_regularizer,
  154. constraint=self.bias_constraint)
  155. else:
  156. self.bias = None
  157.  
  158. self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
  159. self.built = True
  160.  
  161. def call(self, inputs):
  162. spike=tfp.distributions.Bernoulli(logits=None,probs=self.pi, dtype=tf.dtypes.float32).sample()
  163.  
  164. output = K.dot(inputs, (spike*self.kernel))
  165. if self.use_bias:
  166. output = K.bias_add(output, self.bias, data_format='channels_last')
  167. if self.activation is not None:
  168. output = self.activation(output)
  169. return output
  170.  
  171. def compute_output_shape(self, input_shape):
  172. assert input_shape and len(input_shape) >= 2
  173. assert input_shape[-1]
  174. output_shape = list(input_shape)
  175. output_shape[-1] = self.units
  176. return tuple(output_shape)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement