Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class MultiWeight(Layer):
    """Dense-style layer that applies a separate weight set to each input row.

    Input row ``s`` (for ``s`` in ``range(num_sets)``) is multiplied by its own
    ``(input_dim, units)`` kernel and offset by its own ``(units,)`` bias. All
    ``num_sets`` kernels are stored side by side in one variable ``w`` of
    shape ``(input_dim, units * num_sets)``: set ``s`` occupies columns
    ``s * units`` up to ``(s + 1) * units``. The biases are stored the same
    way in ``b`` of shape ``(units * num_sets,)``.

    Args:
        units: Output width of each weight set.
        num_sets: Number of independent weight sets, i.e. the number of input
            rows consumed per call.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units, num_sets, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_sets = num_sets

    def build(self, input_shape):
        """Create the packed kernel ``w`` and bias ``b`` for all weight sets."""
        w_init = tf.random_normal_initializer()
        b_init = tf.zeros_initializer()
        self.w = tf.Variable(
            initial_value=w_init(
                shape=(input_shape[-1], self.units * self.num_sets),
                dtype='float32'),
            trainable=True)
        self.b = tf.Variable(
            initial_value=b_init(
                shape=(self.units * self.num_sets,), dtype='float32'),
            trainable=True)

    def call(self, inputs):
        """Apply weight set ``s`` to input row ``s``.

        Args:
            inputs: Presumably a rank-2 ``(rows, input_dim)`` tensor with
                ``rows >= num_sets`` (TODO confirm with callers). Rows past
                index ``num_sets - 1`` are ignored.

        Returns:
            A ``(num_sets, units)`` tensor whose row ``s`` equals
            ``inputs[s] @ w[:, s*units:(s+1)*units] + b[s*units:(s+1)*units]``.
        """
        # One vectorized contraction instead of a per-row Python loop: the
        # result does not depend on the leading dimension being known at trace
        # time (it is None inside tf.function / model.fit for a Keras Input),
        # so every set always sees its own row.
        rows = inputs[:self.num_sets]
        # Unpack the side-by-side column blocks: kernels[i, s, u] == w[i, s*units + u].
        kernels = tf.reshape(self.w, (-1, self.num_sets, self.units))
        biases = tf.reshape(self.b, (self.num_sets, self.units))
        outputs = tf.einsum('si,isu->su', rows, kernels) + biases
        # Keep the static (num_sets, units) shape so downstream layers can
        # infer their own shapes even when the input batch dim is unknown.
        return tf.ensure_shape(outputs, (self.num_sets, self.units))

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({'units': self.units, 'num_sets': self.num_sets})
        return config
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement