Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- ValueError Traceback (most recent call last)
- <ipython-input-17-ab5eae0dd51a> in <module>
- 3 ]
- 4
- ----> 5 run_nuts(
- 6 target_log_prob_fn,
- 7 inits[:-1]
- <ipython-input-13-2d7a920574a8> in run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)
- 45 )
- 46
- ---> 47 res = tfp.mcmc.sample_chain(
- 48 num_results=num_steps,
- 49 num_burnin_steps=num_burnin,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
- 359 return seed, next_state, current_kernel_results
- 360
- --> 361 (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
- 362 loop_fn=_trace_scan_fn,
- 363 initial_state=(seed, current_state, previous_kernel_results),
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, name)
- 462 return i + 1, state, num_steps_traced, trace_arrays
- 463
- --> 464 _, final_state, _, trace_arrays = tf.while_loop(
- 465 cond=lambda i, *_: i < length,
- 466 body=_body,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
- 603 func.__module__, arg_name, arg_value, 'in a future version'
- 604 if date is None else ('after %s' % date), instructions)
- --> 605 return func(*args, **kwargs)
- 606
- 607 doc = _add_deprecated_arg_value_notice_to_docstring(
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
- 2487
- 2488 """
- -> 2489 return while_loop(
- 2490 cond=cond,
- 2491 body=body,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
- 2733 list(loop_vars))
- 2734 while cond(*loop_vars):
- -> 2735 loop_vars = body(*loop_vars)
- 2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
- 2737 packed = True
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in _body(i, state, num_steps_traced, trace_arrays)
- 452 def _body(i, state, num_steps_traced, trace_arrays):
- 453 elem = elems_array.read(i)
- --> 454 state = loop_fn(state, elem)
- 455
- 456 trace_arrays, num_steps_traced = ps.cond(
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _trace_scan_fn(seed_state_and_results, num_steps)
- 352
- 353 def _trace_scan_fn(seed_state_and_results, num_steps):
- --> 354 seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
- 355 loop_num_iter=num_steps,
- 356 body_fn=_seeded_one_step,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
- 346 # where while/LoopCond needs it.
- 347 loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
- --> 348 return tf.while_loop(
- 349 cond=lambda i, *args: i < loop_num_iter,
- 350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
- 603 func.__module__, arg_name, arg_value, 'in a future version'
- 604 if date is None else ('after %s' % date), instructions)
- --> 605 return func(*args, **kwargs)
- 606
- 607 doc = _add_deprecated_arg_value_notice_to_docstring(
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
- 2487
- 2488 """
- -> 2489 return while_loop(
- 2490 cond=cond,
- 2491 body=body,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
- 2733 list(loop_vars))
- 2734 while cond(*loop_vars):
- -> 2735 loop_vars = body(*loop_vars)
- 2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
- 2737 packed = True
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in <lambda>(i, *args)
- 348 return tf.while_loop(
- 349 cond=lambda i, *args: i < loop_num_iter,
- --> 350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
- 351 loop_vars=[np.int32(0)] + initial_loop_vars,
- 352 parallel_iterations=parallel_iterations
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _seeded_one_step(seed, *state_and_results)
- 349 one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
- 350 return [passalong_seed] + list(
- --> 351 kernel.one_step(*state_and_results, **one_step_kwargs))
- 352
- 353 def _trace_scan_fn(seed_state_and_results, num_steps):
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py in one_step(self, current_state, previous_kernel_results, seed)
- 454 # Step the inner kernel.
- 455 inner_kwargs = {} if seed is None else dict(seed=seed)
- --> 456 new_state, new_inner_results = self.inner_kernel.one_step(
- 457 current_state, inner_results, **inner_kwargs)
- 458
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py in one_step(self, current_state, previous_kernel_results, seed)
- 397 self.name, 'transformed_kernel', 'one_step')):
- 398 inner_kwargs = {} if seed is None else dict(seed=seed)
- --> 399 transformed_next_state, kernel_results = self._inner_kernel.one_step(
- 400 previous_kernel_results.transformed_state,
- 401 previous_kernel_results.inner_results,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in one_step(self, current_state, previous_kernel_results, seed)
- 392 )
- 393
- --> 394 _, _, _, new_step_metastate = tf.while_loop(
- 395 cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda
- 396 (iter_ < self.max_tree_depth) &
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
- 603 func.__module__, arg_name, arg_value, 'in a future version'
- 604 if date is None else ('after %s' % date), instructions)
- --> 605 return func(*args, **kwargs)
- 606
- 607 doc = _add_deprecated_arg_value_notice_to_docstring(
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
- 2487
- 2488 """
- -> 2489 return while_loop(
- 2490 cond=cond,
- 2491 body=body,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
- 2733 list(loop_vars))
- 2734 while cond(*loop_vars):
- -> 2735 loop_vars = body(*loop_vars)
- 2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
- 2737 packed = True
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, state, metastate)
- 396 (iter_ < self.max_tree_depth) &
- 397 tf.reduce_any(metastate.continue_tree)),
- --> 398 body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
- 399 previous_kernel_results.step_size,
- 400 previous_kernel_results.momentum_state_memory,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_tree_doubling(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate, seed)
- 570 momentum_subtree_cumsum,
- 571 leapfrogs_taken
- --> 572 ] = self._build_sub_tree(
- 573 directions_expanded,
- 574 integrator,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name)
- 750 final_not_divergence,
- 751 momentum_state_memory,
- --> 752 ] = tf.while_loop(
- 753 cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda
- 754 leapfrogs_taken, state, state_c, continue_tree,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
- 603 func.__module__, arg_name, arg_value, 'in a future version'
- 604 if date is None else ('after %s' % date), instructions)
- --> 605 return func(*args, **kwargs)
- 606
- 607 doc = _add_deprecated_arg_value_notice_to_docstring(
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
- 2487
- 2488 """
- -> 2489 return while_loop(
- 2490 cond=cond,
- 2491 body=body,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
- 2733 list(loop_vars))
- 2734 while cond(*loop_vars):
- -> 2735 loop_vars = body(*loop_vars)
- 2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
- 2737 packed = True
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)
- 758 leapfrogs_taken, state, state_c, continue_tree,
- 759 not_divergence, momentum_state_memory: (
- --> 760 self._loop_build_sub_tree(
- 761 directions, integrator, current_step_meta_info,
- 762 iter_, energy_diff_sum, init_momentum_cumsum,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_build_sub_tree(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory, seed)
- 811 next_target,
- 812 next_target_grad_parts
- --> 813 ] = integrator(prev_tree_state.momentum,
- 814 prev_tree_state.state,
- 815 prev_tree_state.target,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in __call__(self, momentum_parts, state_parts, target, target_grad_parts, kinetic_energy_fn, name)
- 295 next_target,
- 296 next_target_grad_parts,
- --> 297 ] = tf.while_loop(
- 298 cond=lambda i, *_: i < self.num_steps,
- 299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
- 603 func.__module__, arg_name, arg_value, 'in a future version'
- 604 if date is None else ('after %s' % date), instructions)
- --> 605 return func(*args, **kwargs)
- 606
- 607 doc = _add_deprecated_arg_value_notice_to_docstring(
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
- 2487
- 2488 """
- -> 2489 return while_loop(
- 2490 cond=cond,
- 2491 body=body,
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
- 2733 list(loop_vars))
- 2734 while cond(*loop_vars):
- -> 2735 loop_vars = body(*loop_vars)
- 2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
- 2737 packed = True
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in <lambda>(i, *args)
- 297 ] = tf.while_loop(
- 298 cond=lambda i, *_: i < self.num_steps,
- --> 299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
- 300 self.target_fn, self.step_sizes, get_velocity_parts, *args)),
- 301 loop_vars=[
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in _one_step(target_fn, step_sizes, get_velocity_parts, half_next_momentum_parts, state_parts, target, target_grad_parts)
- 353 next_target_grad_parts))
- 354
- --> 355 tensorshape_util.set_shape(next_target, target.shape)
- 356 for ng, g in zip(next_target_grad_parts, target_grad_parts):
- 357 tensorshape_util.set_shape(ng, g.shape)
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py in set_shape(tensor, shape)
- 326 """
- 327 if hasattr(tensor, 'set_shape'):
- --> 328 tensor.set_shape(shape)
- 329
- 330
- ~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
- 1213 def set_shape(self, shape):
- 1214 if not self.shape.is_compatible_with(shape):
- -> 1215 raise ValueError(
- 1216 "Tensor's shape %s is not compatible with supplied shape %s" %
- 1217 (self.shape, shape))
- ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement