Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import tensorflow_probability as tfp
- import tensorflow_probability.python.bijectors as tfb
- import tensorflow_probability.python.distributions as tfd
- def trainable_lu_factorization(event_size,
- trainable=True,
- batch_shape=(),
- seed=None,
- dtype=tf.float32,
- name=None):
- with tf.name_scope('trainable_lu_factorization'):
- event_size = tf.convert_to_tensor(event_size,
- dtype=tf.int32,
- name='event_size')
- batch_shape = tf.convert_to_tensor(batch_shape,
- dtype=event_size.dtype,
- name='batch_shape')
- random_matrix = tf.Variable(tf.random.uniform(
- shape=tf.concat([batch_shape, [event_size, event_size]], axis=0),
- dtype=dtype,
- seed=seed,
- ),
- trainable=trainable,
- name='conv1x1_weights')
- def lu_p(random_matrix):
- return tf.linalg.lu(tf.linalg.qr(random_matrix).q)
- lower_upper = tfp.util.DeferredTensor(lambda m: lu_p(m)[0],
- random_matrix)
- permutation = tfp.util.DeferredTensor(lambda m: lu_p(m)[1],
- random_matrix,
- dtype=tf.int32,
- shape=random_matrix.shape[:-1])
- return lower_upper, permutation
- def build_model():
- channels = 3
- trainable = True
- # conv1x1 setup
- t_lower_upper, t_permutation = trainable_lu_factorization(
- channels, trainable)
- conv1x1 = tfb.MatvecLU(t_lower_upper, t_permutation, name='MatvecLU')
- inv_conv1x1 = tfb.Invert(conv1x1)
- # forward setup
- fwd = tfp.layers.DistributionLambda(
- lambda x: conv1x1(tfd.Deterministic(x)))
- fwd.bijector = conv1x1
- # inverse setup
- inv = tfp.layers.DistributionLambda(
- lambda x: inv_conv1x1(tfd.Deterministic(x)))
- inv.bijector = inv_conv1x1
- x: tf.Tensor = tf.keras.Input(shape=[28, 28, channels])
- fwd_x: tfp.distributions.TransformedDistribution = fwd(x)
- # fwd_x: tf.Tensor = fwd_x.sample()
- rev_fwd_x: tfp.distributions.TransformedDistribution = inv(fwd_x)
- # rev_fwd_x: tf.Tensor = rev_fwd_x.sample()
- example_model = tf.keras.Model(inputs=x, outputs=rev_fwd_x)
- return example_model
- def main():
- print('tensorflow : ', tf.__version__) # 2.0.0-rc0
- print('tensorflow-probability : ', tfp.__version__) # 0.8.0-rc0
- # setup environment
- example_model = build_model()
- example_model.summary()
- real_x = tf.random.uniform(shape=[2, 28, 28, 3], dtype=tf.float32)
- if example_model.weights == []:
- print('No Trainable Variable exists')
- else:
- print('Some Trainable Variables exist')
- with tf.GradientTape() as tape:
- tape.watch(real_x)
- out_x = example_model(real_x)
- out_x = out_x
- loss = out_x - real_x
- print(tf.math.reduce_sum(real_x - out_x))
- # => nealy 0
- # ex. tf.Tensor(1.3522818e-05, shape=(), dtype=float32)
- try:
- print(tape.gradient(loss, real_x))
- except Exception as e:
- print('Cannot Calculate Gradient')
- print(e)
- if __name__ == '__main__':
- main()
- ##########################################################
- # tensorflow : 2.0.0-rc0
- # tensorflow-probability : 0.8.0-rc0
- # Model: "model_35"
- # _________________________________________________________________
- # Layer (type) Output Shape Param #
- # =================================================================
- # input_36 (InputLayer) [(None, 28, 28, 3)] 0
- # _________________________________________________________________
- # distribution_lambda_70 (Dist ((None, 28, 28, 3), (None 0
- # _________________________________________________________________
- # distribution_lambda_71 (Dist ((None, 28, 28, 3), (None 0
- # =================================================================
- # Total params: 0
- # Trainable params: 0
- # Non-trainable params: 0
- # _________________________________________________________________
- # No Trainable Variable exists
- # tf.Tensor(5.712954e-05, shape=(), dtype=float32)
- # Cannot Calculate Gradient
- # gradient registry has no entry for: Lu
- ###########################################################
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement