Advertisement
Guest User

Untitled

a guest
May 24th, 2019
142
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.50 KB | None | 0 0
  1. def run(D0):
  2. dt=0.001
  3. kBT=0.1
  4. energy_fn = morse_pairwise(displacement, D0=D0)
  5. init, apply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kBT)
  6. apply = jit(apply)
  7. state = init(key, Rinit)
  8. T = []
  9. H = []
  10. E = []
  11.  
  12. print_every = 50
  13. old_time = time.clock()
  14. print('Step\tT_goal\tT\ttime/step\tclassification')
  15. print('---------------------------------------------------------')
  16.  
  17. for i in range(100):
  18. t = i * dt
  19. state = apply(state, t=t)
  20. T += [quantity.temperature(state.velocity)]
  21. H += [invariant(kBT, state,energy_fn)]
  22. E += [energy_fn(state.position)]
  23. state = apply(state, t=t)
  24.  
  25. if i % print_every == 0:# and i > 0:
  26. #save file
  27. if i==0:
  28. command = 'w'
  29. else:
  30. command = 'a'
  31. SaveState("testOut.xyz",state.position,str(i),command)
  32.  
  33. #new_time = time.clock()
  34. #print('{}\t{:.2f}\t{:.2f}\t{:.3f}\t\t{}'.format(i, kBT, T[-1], (new_time - old_time) / print_every, classify(state.position)))
  35. #old_time = new_time
  36.  
  37. T = np.array(T)
  38. H = np.array(H)
  39. R = state.position
  40. #return state
  41. #return np.sum(space.distance(displacement(state.position,state.position)))
  42. return energy_fn(state.position)
  43.  
  44. grad(run)(D0)
  45.  
  46. ###################
  47. ##### output:######
  48. ###################
  49.  
  50. Step T_goal T time/step classification
  51. ---------------------------------------------------------
  52. ---------------------------------------------------------------------------
  53. AttributeError Traceback (most recent call last)
  54. <ipython-input-14-d272be2c265c> in <module>()
  55. ----> 1 grad(run)(D0)
  56.  
  57. 13 frames
  58. /usr/local/lib/python3.6/dist-packages/jax/api.py in grad_f(*args, **kwargs)
  59. 233 def grad_f(*args, **kwargs):
  60. 234 if not has_aux:
  61. --> 235 _, g = value_and_grad_f(*args, **kwargs)
  62. 236 return g
  63. 237 else:
  64.  
  65. /usr/local/lib/python3.6/dist-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
  66. 287 "differentiation, pass holomorphic=True.")
  67. 288 raise TypeError(msg.format(dtype))
  68. --> 289 g = vjp_py(onp.ones((), dtype=dtype))
  69. 290 g = g[0] if isinstance(argnums, int) else g
  70. 291 if not has_aux:
  71.  
  72. /usr/local/lib/python3.6/dist-packages/jax/api_util.py in apply_jaxtree_fun(fun, io_tree, *py_args)
  73. 60 raise TypeError("Expected {}, got {}".format(expected, in_tree))
  74. 61
  75. ---> 62 ans = fun(*args)
  76. 63 return build_tree(out_tree, ans)
  77. 64
  78.  
  79. /usr/local/lib/python3.6/dist-packages/jax/api.py in out_vjp_packed(cotangent_in)
  80. 820 ct_out_tree = PyTreeDef(node_types[tuple], None, in_trees)
  81. 821 def out_vjp_packed(cotangent_in):
  82. --> 822 return out_vjp(cotangent_in)
  83. 823 vjp_py = partial(apply_jaxtree_fun, out_vjp_packed, (ct_in_trees, ct_out_tree))
  84. 824 if not has_aux:
  85.  
  86. /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in vjp_(ct)
  87. 110 dummy_primal_and_ct = pack((core.unit, ct))
  88. 111 dummy_args = (None,) * len(jaxpr.invars)
  89. --> 112 _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primal_and_ct)
  90. 113 return instantiate_zeros(pack(primals), arg_cts[1])
  91. 114
  92.  
  93. /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in)
  94. 178 for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs])
  95. 179 cts_out, ct_free_vars_out = get_primitive_transpose(eqn.primitive)(
  96. --> 180 eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)
  97. 181 # TODO(dougalm): support cases != 1
  98. 182 assert(len(eqn.bound_subjaxprs) == 1)
  99.  
  100. /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct)
  101. 534 all_args = pack((pack(args), pack(consts), pack(freevar_vals), ct))
  102. 535 # TODO(dougalm): consider signalling to bind that no traces in fun closure
  103. --> 536 ans = primitive.bind(fun, all_args, **params)
  104. 537 return build_tree(out_tree_def(), ans)
  105. 538
  106.  
  107. /usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
  108. 634 if top_trace is None:
  109. 635 with new_sublevel():
  110. --> 636 ans = primitive.impl(f, *args, **params)
  111. 637 else:
  112. 638 tracers = map(top_trace.full_raise, args)
  113.  
  114. /usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
  115. 589 def xla_call_impl(fun, *args, **params):
  116. 590 device_values = FLAGS.jax_device_values and params.pop('device_values')
  117. --> 591 compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
  118. 592 try:
  119. 593 return compiled_fun(*args)
  120.  
  121. /usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
  122. 206 if len(cache) > max_size:
  123. 207 cache.popitem(last=False)
  124. --> 208 ans = call(f, *args)
  125. 209 cache[key] = (ans, f)
  126. 210 return ans
  127.  
  128. /usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
  129. 602 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
  130. 603 with core.new_master(pe.JaxprTrace, True) as master:
  131. --> 604 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  132. 605 assert not env # no subtraces here (though cond might eventually need them)
  133. 606 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
  134.  
  135. /usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
  136. 145
  137. 146 del gen
  138. --> 147 ans = self.f(*args, **dict(self.params, **kwargs))
  139. 148 del args
  140. 149 while stack:
  141.  
  142. /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in)
  143. 184 map(write_cotangent, bound_vars, ct_free_vars_out)
  144. 185 else:
  145. --> 186 cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
  146. 187
  147. 188 if cts_out is zero:
  148.  
  149. /usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _scatter_add_transpose_rule(t, operand, scatter_indices, updates, update_jaxpr, update_consts, dimension_numbers, updates_shape)
  150. 2816 slice_sizes = []
  151. 2817 pos = 0
  152. -> 2818 for i in xrange(len(t.shape)):
  153. 2819 if i in dimension_numbers.inserted_window_dims:
  154. 2820 slice_sizes.append(1)
  155.  
  156. AttributeError: 'Zero' object has no attribute 'shape'
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement