Advertisement
Guest User

Untitled

a guest
Dec 8th, 2023
142
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.17 KB | None | 0 0
  1. class MultiWeight(Layer):
  2.  
  3.     def __init__(self, units, num_sets):
  4.         super(MultiWeight, self).__init__()
  5.         self.units = units
  6.         self.num_sets = num_sets
  7.  
  8.     def build(self, input_shape):  # Create the state of the layer (weights)
  9.         w_init = tf.random_normal_initializer()
  10.         b_init = tf.zeros_initializer()
  11.  
  12.         self.w = tf.Variable(
  13.             initial_value=w_init(shape=(input_shape[-1], self.units * self.num_sets),
  14.                                  dtype='float32'),
  15.             trainable=True)
  16.  
  17.         self.b = tf.Variable(
  18.             initial_value=b_init(shape=(self.units * self.num_sets,), dtype='float32'),
  19.             trainable=True)
  20.  
  21.     def call(self, inputs):  # Defines the computation from inputs to outputs
  22.         ts = []
  23.         for x in range(self.num_sets):
  24.             s_ind = x * self.units
  25.             e_ind = (x + 1) * self.units
  26.             if inputs.shape[0] is None:
  27.                 ts.append(tf.matmul([inputs[0]], self.w[:, s_ind:e_ind]) + self.b[s_ind:e_ind])
  28.             else:
  29.                 ts.append(tf.matmul([inputs[x]], self.w[:, s_ind:e_ind]) + self.b[s_ind:e_ind])
  30.         return tf.concat(ts, 0)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement