import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class DenseSVD(tf.keras.layers.Layer):
    """Dense layer factored into two low-rank matrices (W ~= u @ n), SVD-style."""

    def __init__(self, units, hidden_units, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.hidden_units = hidden_units
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        _, features = input_shape
        w_init = tf.random_normal_initializer()
        # Low-rank factors: u maps features -> hidden_units, n maps hidden_units -> units.
        self.u = tf.Variable(name='u', trainable=True,
                             initial_value=w_init(shape=[features, self.hidden_units], dtype='float32'))
        self.n = tf.Variable(name='n', trainable=True,
                             initial_value=w_init(shape=[self.hidden_units, self.units], dtype='float32'))
        b_init = tf.zeros_initializer()
        # A single bias on the output of the factored product.
        self.bias0 = tf.Variable(name='bias0', trainable=True,
                                 initial_value=b_init(shape=[self.units], dtype='float32'))

    def get_config(self):
        config = super().get_config()
        # Serialize 'activation' too, so the layer round-trips through save/load.
        config.update({'units': self.units,
                       'hidden_units': self.hidden_units,
                       'activation': tf.keras.activations.serialize(self.activation)})
        return config

    def call(self, inputs):
        # y = activation((x @ u) @ n + bias0)
        x = tf.matmul(inputs, self.u)
        x = tf.matmul(x, self.n)
        return self.activation(x + self.bias0)
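
A minimal usage sketch follows (the input size of 128 and the file name dense_svd.keras are illustrative assumptions, not part of the paste; saving in .keras format assumes a recent TF 2.x). Factoring the weight matrix as [features, hidden_units] @ [hidden_units, units] saves parameters whenever hidden_units is small relative to features and units, which is the usual motivation for an SVD-style dense layer. Because the layer is registered as Keras-serializable and reports its constructor arguments in get_config, it survives both a from_config round trip and a full model save/load.

layer = DenseSVD(units=10, hidden_units=16, activation='relu')
x = tf.random.normal([4, 128])   # batch of 4 examples, 128 input features
y = layer(x)                     # -> shape (4, 10)

# Round-trip via get_config/from_config:
clone = DenseSVD.from_config(layer.get_config())

# Or through a saved model; no custom_objects needed thanks to the decorator:
model = tf.keras.Sequential([layer])
model.build(input_shape=(None, 128))
model.save('dense_svd.keras')
restored = tf.keras.models.load_model('dense_svd.keras')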