Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Notebook setup cell (IPython/Colab): the `!` lines are shell escapes, not Python.
# Install a CUDA-matched jaxlib wheel, then jax and jax-md.
# The CUDA tag is derived from $CUDA_VERSION by the two sed expressions
# (e.g. "10.1.243" -> "101"); the jaxlib version is the first version string
# printed by `pip search jaxlib`.
# NOTE(review): PyPI no longer serves `pip search`, so this version lookup is
# likely to fail today; verify and pin an explicit jaxlib version instead.
# NOTE(review): the wheel tag is hard-coded to cp36 (CPython 3.6); confirm it
# matches the runtime's Python version.
!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-$(pip search jaxlib | grep -oP '[0-9\.]+' | head -n 1)-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
# Installs jax-md from the repository's default branch (no revision pinned).
# NOTE(review): brownian() below calls simulate.brownian(..., T_schedule=...);
# newer jax-md releases may not accept that keyword. Pin a matching revision (verify).
!pip install -q git+https://www.github.com/google/jax-md
# Imports used by this notebook (not all of them are used in the code shown here).
import numpy as onp
import jax.numpy as np
from jax import random, jit, vmap, grad, lax#, ops
from jax_md import space, smap, energy, minimize, quantity, simulate
from jax.config import config
# Enable 64-bit floats; brownian() requests float64 initial positions.
config.update("jax_enable_x64", True)
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn plot style for every figure drawn afterwards.
sns.set_style(style='white')
def format_plot(x, y):
    """Enable the grid and set large axis labels on the current plot.

    Args:
      x: text for the x-axis label.
      y: text for the y-axis label.
    """
    label_size = 20
    plt.grid(True)
    plt.xlabel(x, fontsize=label_size)
    plt.ylabel(y, fontsize=label_size)
def finalize_plot(shape=(1, 1)):
    """Resize the current figure and apply a tight layout.

    Both dimensions are scaled from the figure's current *height*: the new
    width is shape[0] * 1.5 * height and the new height is
    shape[1] * 1.5 * height, so shape=(1, 1) yields a square figure.

    Args:
      shape: pair of (width, height) multipliers.
    """
    fig = plt.gcf()
    height = fig.get_size_inches()[1]
    fig.set_size_inches(shape[0] * 1.5 * height, shape[1] * 1.5 * height)
    plt.tight_layout()
import time

# Module-level PRNG key with a fixed seed, so every run is reproducible.
key = random.PRNGKey(seed=0)
def brownian(key, batch_size=1):
    """Run a short annealed Brownian-dynamics simulation of soft spheres in 2D.

    N = 100 particles start at uniformly random positions in a periodic square
    box of side 4.9. They evolve for 1000 Brownian steps of size dt = 1e-4 under
    a soft-sphere pair energy (epsilon = 5.0). The temperature follows
    T(t) = exp(-t), where t = i * dt is the simulation time at step i.

    Args:
      key: JAX PRNG key. It is split once to draw the initial positions and
        once more to seed the simulator state.
      batch_size: currently unused; kept so existing callers keep working.

    Returns:
      A tuple ``(energy, positions)``:
        energy: the value of the soft-sphere ``energy_fn`` for the final
          configuration.
        positions: the final ``(N, 2)`` particle positions.
    """
    # Simulation parameters.
    N = 100
    box_size = 4.9
    displacement, shift = space.periodic(box_size)
    dimension = 2
    dt = 1e-4
    num_steps = 1000
    T_schedule = lambda t: np.exp(-t)

    # Random initial configuration inside the box.
    key, split = random.split(key)
    R = random.uniform(split, (N, dimension), minval=0.0, maxval=box_size,
                       dtype=np.float64)

    # Build and compile the integrator.
    # NOTE(review): this targets the older jax_md API, where simulate.brownian
    # takes `T_schedule` and apply(state, t=...) evaluates the schedule at
    # time t. Newer releases take `kT` instead; verify against the installed
    # version.
    energy_fn = energy.soft_sphere_pair(displacement, epsilon=5.0)
    init, apply = simulate.brownian(energy_fn, shift, dt=dt, T_schedule=T_schedule)
    key, split = random.split(key)
    apply = jit(apply)
    state = init(split, R)

    # Integrate. The current simulation time must be handed to `apply`.
    # Without it, the schedule is always evaluated at the default time and the
    # temperature never anneals.
    for i in range(num_steps):
        state = apply(state, t=i * dt)
    return energy_fn(state.position), state.position
# Run the simulation once with the module-level key; keep the final energy and positions.
eng, R = brownian(key, batch_size=1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement