Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def run(D0, *, dt=0.001, kBT=0.1, steps=100, save_every=50,
        out_path="testOut.xyz"):
    """Run a short Nose-Hoover NVT simulation and return the final energy.

    The result is a scalar, so this function is meant to be wrapped with
    ``jax.grad`` to differentiate the final energy with respect to ``D0``.

    Relies on names defined elsewhere in the notebook: ``morse_pairwise``,
    ``displacement``, ``shift``, ``simulate``, ``jit``, ``key``, ``Rinit``
    and ``SaveState``.

    Args:
      D0: Forwarded to ``morse_pairwise(displacement, D0=D0)``; this is the
        argument differentiated by ``grad(run)`` (presumably the Morse well
        depth).
      dt: Time step passed to ``simulate.nvt_nose_hoover``; also used to
        compute the ``t`` keyword handed to ``apply``. Default 0.001.
      kBT: Temperature passed to ``simulate.nvt_nose_hoover`` (the value the
        original progress table labelled ``T_goal``). Default 0.1.
      steps: Number of loop iterations. Each iteration calls ``apply`` twice.
        Default 100.
      save_every: Call ``SaveState`` on iterations where
        ``i % save_every == 0`` (always including iteration 0). Must be
        positive. Default 50.
      out_path: File name given to ``SaveState``. Mode ``'w'`` is passed on
        iteration 0 and ``'a'`` afterwards (presumably truncate, then
        append; ``SaveState`` is defined elsewhere). Default "testOut.xyz".

    Returns:
      ``energy_fn(state.position)`` evaluated on the final state.

    Raises:
      ValueError: if ``save_every`` is not positive.
    """
    if save_every <= 0:
        raise ValueError("save_every must be a positive integer")

    energy_fn = morse_pairwise(displacement, D0=D0)
    init, apply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kBT)
    # NOTE(review): grad(run) failed with
    # "AttributeError: 'Zero' object has no attribute 'shape'". The error is
    # raised inside JAX's own _scatter_add_transpose_rule while transposing a
    # jitted call (see the traceback pasted with this code): JAX library code
    # received a symbolic-zero cotangent. Upgrading JAX is the first thing to
    # try; confirm before changing this function's logic.
    apply = jit(apply)
    state = init(key, Rinit)

    # Header kept from the original. Its per-step rows were commented out
    # there. NOTE(review): possibly because formatting traced values with
    # '{:.2f}' fails under grad; confirm before re-enabling.
    print('Step\tT_goal\tT\ttime/step\tclassification')
    print('---------------------------------------------------------')
    for i in range(steps):
        t = i * dt
        # NOTE(review): apply runs twice per iteration with the same t, as in
        # the original. If each call advances one dt, `steps` iterations
        # integrate 2 * steps steps. Confirm the second call is intentional.
        state = apply(state, t=t)
        state = apply(state, t=t)
        if i % save_every == 0:
            mode = 'w' if i == 0 else 'a'
            SaveState(out_path, state.position, str(i), mode)
    return energy_fn(state.position)
# Differentiate run()'s final energy with respect to its first argument, D0.
# argnums=0 is jax.grad's default, written out to make the target explicit.
grad(run, argnums=0)(D0)
- ###################
- ##### output:######
- ###################
- Step T_goal T time/step classification
- ---------------------------------------------------------
- ---------------------------------------------------------------------------
- AttributeError Traceback (most recent call last)
- <ipython-input-14-d272be2c265c> in <module>()
- ----> 1 grad(run)(D0)
- 13 frames
- /usr/local/lib/python3.6/dist-packages/jax/api.py in grad_f(*args, **kwargs)
- 233 def grad_f(*args, **kwargs):
- 234 if not has_aux:
- --> 235 _, g = value_and_grad_f(*args, **kwargs)
- 236 return g
- 237 else:
- /usr/local/lib/python3.6/dist-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
- 287 "differentiation, pass holomorphic=True.")
- 288 raise TypeError(msg.format(dtype))
- --> 289 g = vjp_py(onp.ones((), dtype=dtype))
- 290 g = g[0] if isinstance(argnums, int) else g
- 291 if not has_aux:
- /usr/local/lib/python3.6/dist-packages/jax/api_util.py in apply_jaxtree_fun(fun, io_tree, *py_args)
- 60 raise TypeError("Expected {}, got {}".format(expected, in_tree))
- 61
- ---> 62 ans = fun(*args)
- 63 return build_tree(out_tree, ans)
- 64
- /usr/local/lib/python3.6/dist-packages/jax/api.py in out_vjp_packed(cotangent_in)
- 820 ct_out_tree = PyTreeDef(node_types[tuple], None, in_trees)
- 821 def out_vjp_packed(cotangent_in):
- --> 822 return out_vjp(cotangent_in)
- 823 vjp_py = partial(apply_jaxtree_fun, out_vjp_packed, (ct_in_trees, ct_out_tree))
- 824 if not has_aux:
- /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in vjp_(ct)
- 110 dummy_primal_and_ct = pack((core.unit, ct))
- 111 dummy_args = (None,) * len(jaxpr.invars)
- --> 112 _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primal_and_ct)
- 113 return instantiate_zeros(pack(primals), arg_cts[1])
- 114
- /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in)
- 178 for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs])
- 179 cts_out, ct_free_vars_out = get_primitive_transpose(eqn.primitive)(
- --> 180 eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)
- 181 # TODO(dougalm): support cases != 1
- 182 assert(len(eqn.bound_subjaxprs) == 1)
- /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct)
- 534 all_args = pack((pack(args), pack(consts), pack(freevar_vals), ct))
- 535 # TODO(dougalm): consider signalling to bind that no traces in fun closure
- --> 536 ans = primitive.bind(fun, all_args, **params)
- 537 return build_tree(out_tree_def(), ans)
- 538
- /usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
- 634 if top_trace is None:
- 635 with new_sublevel():
- --> 636 ans = primitive.impl(f, *args, **params)
- 637 else:
- 638 tracers = map(top_trace.full_raise, args)
- /usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
- 589 def xla_call_impl(fun, *args, **params):
- 590 device_values = FLAGS.jax_device_values and params.pop('device_values')
- --> 591 compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
- 592 try:
- 593 return compiled_fun(*args)
- /usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
- 206 if len(cache) > max_size:
- 207 cache.popitem(last=False)
- --> 208 ans = call(f, *args)
- 209 cache[key] = (ans, f)
- 210 return ans
- /usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
- 602 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
- 603 with core.new_master(pe.JaxprTrace, True) as master:
- --> 604 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
- 605 assert not env # no subtraces here (though cond might eventually need them)
- 606 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
- /usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
- 145
- 146 del gen
- --> 147 ans = self.f(*args, **dict(self.params, **kwargs))
- 148 del args
- 149 while stack:
- /usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in)
- 184 map(write_cotangent, bound_vars, ct_free_vars_out)
- 185 else:
- --> 186 cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
- 187
- 188 if cts_out is zero:
- /usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _scatter_add_transpose_rule(t, operand, scatter_indices, updates, update_jaxpr, update_consts, dimension_numbers, updates_shape)
- 2816 slice_sizes = []
- 2817 pos = 0
- -> 2818 for i in xrange(len(t.shape)):
- 2819 if i in dimension_numbers.inserted_window_dims:
- 2820 slice_sizes.append(1)
- AttributeError: 'Zero' object has no attribute 'shape'
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement