Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
@tf.keras.utils.register_keras_serializable()
class DenseSVD(tf.keras.layers.Layer):
    """Dense layer factored into two low-rank matmuls (SVD-style).

    Instead of one ``(features, units)`` kernel, the layer learns
    ``u``: ``(features, hidden_units)`` and ``n``: ``(hidden_units, units)``
    and computes ``activation(inputs @ u @ n + bias0)``. When
    ``hidden_units`` is smaller than ``features``/``units`` this reduces
    the parameter count versus a plain Dense layer.

    Args:
        units: Output dimensionality.
        hidden_units: Inner (bottleneck) dimensionality of the factorization.
        activation: Activation identifier accepted by
            ``tf.keras.activations.get`` (name, callable, or config dict).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, hidden_units, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.hidden_units = hidden_units
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        # assumes input_shape is rank-2 (batch, features) — the unpack
        # below fails for higher-rank inputs; TODO confirm with callers.
        _, features = input_shape
        w_init = tf.random_normal_initializer()
        # Left factor, (features, hidden_units). trainable=True is the
        # default but is stated explicitly for consistency with the others.
        self.u = tf.Variable(
            name='u',
            initial_value=w_init(shape=[features, self.hidden_units],
                                 dtype='float32'),
            trainable=True)
        # Right factor, (hidden_units, units).
        self.n = tf.Variable(
            name='n',
            initial_value=w_init(shape=[self.hidden_units, self.units],
                                 dtype='float32'),
            trainable=True)
        b_init = tf.zeros_initializer()
        # Output bias, (units,).
        self.bias0 = tf.Variable(
            name='bias0',
            initial_value=b_init(shape=[self.units], dtype='float32'),
            trainable=True)

    def get_config(self):
        """Return the layer config for serialization.

        Bug fix: 'activation' was previously omitted, so a layer restored
        from a saved model silently fell back to the default 'relu'
        regardless of the configured activation.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'hidden_units': self.hidden_units,
            'activation': tf.keras.activations.serialize(self.activation),
        })
        return config

    def call(self, inputs):
        # (batch, features) @ (features, hidden) @ (hidden, units) + bias
        x = tf.matmul(inputs, self.u)
        x = tf.matmul(x, self.n)
        return self.activation(x + self.bias0)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement