Advertisement
Guest User

Untitled

a guest
Aug 20th, 2019
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.26 KB | None | 0 0
  1. # observed data
  2. total_rock = tf.constant(5., tf.float32)
  3. total_paper = tf.constant(0., tf.float32)
  4. total_scissors = tf.constant(0., tf.float32)
  5.  
  6. # define some constants
  7. number_of_steps = 10000
  8. burnin = 5000
  9.  
  10. # Set the chain's start state
  11. initial_chain_state = [
  12. 1/3 * tf.ones([], dtype=tf.float32, name="init_p_rock"),
  13. 1/3 * tf.ones([], dtype=tf.float32, name="init_p_paper")
  14. ]
  15.  
  16. # for trainsforming contrained parameter space (in this case, [0, 1] for each parameter) to unconstrained real numbers
  17. unconstraining_bijectors = [
  18. tfp.bijectors.Sigmoid(), # bijector for p_rock
  19. tfp.bijectors.Sigmoid() # bijector for p_paper
  20. ]
  21.  
  22. # fix data to the actually observed values in the joint log prob to optimize only over the unkown model parameters,
  23. # and convert to tensorflow function for speedup
  24. joint_log_prob_for_opt = tf.function(func=lambda x, y: joint_log_prob(total_rock=total_rock,
  25. total_paper=total_paper,
  26. total_scissors=total_scissors,
  27. p_rock=x, p_paper=y),
  28. input_signature=2 * (tf.TensorSpec(shape=[], dtype=tf.float32),)
  29. )
  30.  
  31. # define the Hamilton markov chain by successively adding wrappers to an inner kernel
  32. kernel = tfp.mcmc.HamiltonianMonteCarlo(
  33. target_log_prob_fn=joint_log_prob_for_opt,
  34. num_leapfrog_steps=2,
  35. step_size=tf.constant(0.5, dtype=tf.float32),
  36. state_gradients_are_stopped=True
  37. )
  38. kernel=tfp.mcmc.TransformedTransitionKernel(
  39. inner_kernel=kernel,
  40. bijector=unconstraining_bijectors
  41. )
  42. kernel = tfp.mcmc.SimpleStepSizeAdaptation(
  43. inner_kernel=kernel,
  44. num_adaptation_steps=int(burnin * 0.8)
  45. )
  46.  
  47. # sample from the chain
  48. [
  49. posterior_p_rock,
  50. posterior_p_paper
  51. ], kernel_results = tfp.mcmc.sample_chain(
  52. num_results=number_of_steps,
  53. num_burnin_steps=burnin,
  54. current_state=initial_chain_state,
  55. trace_fn=lambda _, kernel_results: kernel_results.inner_results.inner_results.is_accepted,
  56. kernel=kernel)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement