Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# --- Observed data -----------------------------------------------------------
# Number of times each move was seen, stored as float32 scalars.
total_rock = tf.constant(5.0, dtype=tf.float32)
total_paper = tf.constant(0.0, dtype=tf.float32)
total_scissors = tf.constant(0.0, dtype=tf.float32)

# --- Sampler settings --------------------------------------------------------
number_of_steps = 10000  # samples kept after burn-in
burnin = 5000  # warm-up steps handed to sample_chain as num_burnin_steps

# --- Initial state -----------------------------------------------------------
# Both free parameters start at 1/3.
initial_chain_state = [
    1/3 * tf.ones([], dtype=tf.float32, name="init_p_rock"),
    1/3 * tf.ones([], dtype=tf.float32, name="init_p_paper"),
]

# --- Constrained <-> unconstrained mapping ------------------------------------
# One Sigmoid bijector per state part (p_rock, p_paper): each parameter lives
# in [0, 1], while the inner HMC kernel works on unconstrained real numbers.
unconstraining_bijectors = [
    tfp.bijectors.Sigmoid() for _ in initial_chain_state
]


# --- Target log-probability ---------------------------------------------------
# The observed counts are pinned to the values above so that only the unknown
# model parameters remain as arguments; compiling with tf.function (scalar
# float32 signature for both arguments) is done for speed.
@tf.function(
    input_signature=(
        tf.TensorSpec(shape=[], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.float32),
    )
)
def joint_log_prob_for_opt(x, y):
    """Return joint_log_prob at p_rock=x, p_paper=y for the observed counts."""
    return joint_log_prob(
        total_rock=total_rock,
        total_paper=total_paper,
        total_scissors=total_scissors,
        p_rock=x,
        p_paper=y,
    )


# --- Transition kernel --------------------------------------------------------
# Built inside-out: Hamiltonian Monte Carlo on the unconstrained space, wrapped
# by the bijector transform, wrapped by step-size adaptation that runs for the
# first 80% of the burn-in steps.
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=joint_log_prob_for_opt,
            num_leapfrog_steps=2,
            step_size=tf.constant(0.5, dtype=tf.float32),
            state_gradients_are_stopped=True,
        ),
        bijector=unconstraining_bijectors,
    ),
    num_adaptation_steps=int(burnin * 0.8),
)

# --- Run the chain ------------------------------------------------------------
# trace_fn steps through the two wrapper layers (adaptation -> transformed
# kernel) to reach the HMC results and records only their is_accepted field,
# so `kernel_results` is bound to that trace.
(posterior_p_rock, posterior_p_paper), kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    kernel=kernel,
    trace_fn=lambda _state, results: (
        results.inner_results.inner_results.is_accepted
    ),
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement