import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_probability.python.bijectors as tfb
import tensorflow_probability.python.distributions as tfd


def trainable_lu_factorization(event_size,
                               trainable=True,
                               batch_shape=(),
                               seed=None,
                               dtype=tf.float32,
                               name=None):
    with tf.name_scope('trainable_lu_factorization'):
        event_size = tf.convert_to_tensor(event_size,
                                          dtype=tf.int32,
                                          name='event_size')
        batch_shape = tf.convert_to_tensor(batch_shape,
                                           dtype=event_size.dtype,
                                           name='batch_shape')
        random_matrix = tf.Variable(
            tf.random.uniform(
                shape=tf.concat([batch_shape, [event_size, event_size]],
                                axis=0),
                dtype=dtype,
                seed=seed,
            ),
            trainable=trainable,
            name='conv1x1_weights')

        def lu_p(random_matrix):
            return tf.linalg.lu(tf.linalg.qr(random_matrix).q)

        lower_upper = tfp.util.DeferredTensor(lambda m: lu_p(m)[0],
                                              random_matrix)
        permutation = tfp.util.DeferredTensor(lambda m: lu_p(m)[1],
                                              random_matrix,
                                              dtype=tf.int32,
                                              shape=random_matrix.shape[:-1])
    return lower_upper, permutation
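

# --- Added illustration, not part of the original paste ---
# A minimal sanity check of the factorization on its own, assuming eager
# execution (the TF 2.x default). MatvecLU multiplies the rightmost axis by
# the permuted LU matrix, so forward followed by inverse should recover the
# input up to float error. The name `_sanity_check_conv1x1` is a hypothetical
# helper introduced here for illustration; it is not called anywhere below.
def _sanity_check_conv1x1(event_size=3, seed=42):
    lower_upper, permutation = trainable_lu_factorization(event_size,
                                                          seed=seed)
    bijector = tfb.MatvecLU(lower_upper, permutation)
    v = tf.random.normal([5, event_size])
    roundtrip = bijector.inverse(bijector.forward(v))
    # Expected to print a value close to 0.
    print(tf.reduce_max(tf.abs(roundtrip - v)))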


def build_model():
    channels = 3
    trainable = True
    # conv1x1 setup
    t_lower_upper, t_permutation = trainable_lu_factorization(
        channels, trainable)
    conv1x1 = tfb.MatvecLU(t_lower_upper, t_permutation, name='MatvecLU')
    inv_conv1x1 = tfb.Invert(conv1x1)

    # forward setup
    fwd = tfp.layers.DistributionLambda(
        lambda x: conv1x1(tfd.Deterministic(x)))
    fwd.bijector = conv1x1

    # inverse setup
    inv = tfp.layers.DistributionLambda(
        lambda x: inv_conv1x1(tfd.Deterministic(x)))
    inv.bijector = inv_conv1x1

    x: tf.Tensor = tf.keras.Input(shape=[28, 28, channels])

    fwd_x: tfp.distributions.TransformedDistribution = fwd(x)
    # fwd_x: tf.Tensor = fwd_x.sample()

    rev_fwd_x: tfp.distributions.TransformedDistribution = inv(fwd_x)
    # rev_fwd_x: tf.Tensor = rev_fwd_x.sample()

    example_model = tf.keras.Model(inputs=x, outputs=rev_fwd_x)

    return example_model


def main():
    print('tensorflow : ', tf.__version__)               # 2.0.0-rc0
    print('tensorflow-probability : ', tfp.__version__)  # 0.8.0-rc0
    # setup environment

    example_model = build_model()
    example_model.summary()

    real_x = tf.random.uniform(shape=[2, 28, 28, 3], dtype=tf.float32)
    if not example_model.weights:
        print('No Trainable Variable exists')
    else:
        print('Some Trainable Variables exist')

    with tf.GradientTape() as tape:
        tape.watch(real_x)
        out_x = example_model(real_x)
        loss = out_x - real_x
        print(tf.math.reduce_sum(real_x - out_x))
        # => nearly 0
        # e.g. tf.Tensor(1.3522818e-05, shape=(), dtype=float32)

    try:
        print(tape.gradient(loss, real_x))
    except Exception as e:
        print('Cannot Calculate Gradient')
        print(e)


if __name__ == '__main__':
    main()

##########################################################
# tensorflow : 2.0.0-rc0
# tensorflow-probability : 0.8.0-rc0
# Model: "model_35"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# input_36 (InputLayer)        [(None, 28, 28, 3)]       0
# _________________________________________________________________
# distribution_lambda_70 (Dist ((None, 28, 28, 3), (None 0
# _________________________________________________________________
# distribution_lambda_71 (Dist ((None, 28, 28, 3), (None 0
# =================================================================
# Total params: 0
# Trainable params: 0
# Non-trainable params: 0
# _________________________________________________________________
# No Trainable Variable exists
# tf.Tensor(5.712954e-05, shape=(), dtype=float32)
# Cannot Calculate Gradient
# gradient registry has no entry for: Lu
###########################################################
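

# --- Added illustration, not part of the original paste ---
# The failure above comes from differentiating through tf.linalg.lu inside the
# DeferredTensor: as of TF 2.0.0-rc0 no gradient is registered for the Lu op
# ("gradient registry has no entry for: Lu"). One possible workaround, a
# minimal sketch similar to the example in the tfp.bijectors.MatvecLU
# documentation, is to run the QR + LU factorization once at construction time
# and keep the results in tf.Variables, so gradients flow into `lower_upper`
# directly and never touch tf.linalg.lu. The name
# `trainable_lu_factorization_v2` is a hypothetical name used for illustration.
def trainable_lu_factorization_v2(event_size,
                                  batch_shape=(),
                                  seed=None,
                                  dtype=tf.float32,
                                  name=None):
    with tf.name_scope(name or 'trainable_lu_factorization_v2'):
        event_size = tf.convert_to_tensor(event_size,
                                          dtype=tf.int32,
                                          name='event_size')
        batch_shape = tf.convert_to_tensor(batch_shape,
                                           dtype=event_size.dtype,
                                           name='batch_shape')
        random_matrix = tf.random.uniform(
            shape=tf.concat([batch_shape, [event_size, event_size]], axis=0),
            dtype=dtype,
            seed=seed)
        # Factorize a random orthonormal matrix once, eagerly...
        random_orthonormal = tf.linalg.qr(random_matrix).q
        lower_upper, permutation = tf.linalg.lu(random_orthonormal)
        # ...and keep the factors as Variables: `lower_upper` holds the learned
        # weights, the permutation stays fixed.
        lower_upper = tf.Variable(initial_value=lower_upper,
                                  trainable=True,
                                  name='lower_upper')
        permutation = tf.Variable(initial_value=permutation,
                                  trainable=False,
                                  name='permutation')
    return lower_upper, permutation
# A bijector built from these, e.g.
# tfb.MatvecLU(*trainable_lu_factorization_v2(3)), exposes an ordinary
# tf.Variable that can be handed to an optimizer directly; whether Keras also
# tracks it through DistributionLambda is not verified here.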