Advertisement
Guest User

Untitled

a guest
Aug 27th, 2019
241
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.08 KB | None | 0 0
# Notebook setup cell (IPython/Colab: lines starting with `!` run in the shell).
# 1) Install a CUDA-specific jaxlib wheel. The `cudaXY` path segment is built
#    from $CUDA_VERSION by deleting its first '.' and everything after the
#    next '.'; the jaxlib version is taken from the first number printed by
#    `pip search jaxlib`.
#    NOTE(review): `pip search` has since been disabled on PyPI, so this
#    version lookup may fail today; the wheel tag is also pinned to cp36
#    (CPython 3.6). Confirm both against the target runtime.
!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-$(pip search jaxlib | grep -oP '[0-9\.]+' | head -n 1)-cp36-none-linux_x86_64.whl
# 2) Install jax itself, then jax-md straight from GitHub.
!pip install --upgrade -q jax
!pip install -q git+https://www.github.com/google/jax-md

#import anything needed
# Plain NumPy is aliased `onp`; `np` refers to jax.numpy throughout this file.
import numpy as onp
import jax.numpy as np
from jax import random, jit, vmap, grad, lax#, ops
from jax_md import space, smap, energy, minimize, quantity, simulate
from jax.config import config
# Enable 64-bit floats in JAX. Without this flag JAX downcasts float64
# requests (e.g. the dtype=np.float64 positions in brownian) to float32.
config.update("jax_enable_x64", True)
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn theme: white background for every plot in the notebook.
sns.set_style(style='white')
  16. def format_plot(x, y):
  17. plt.grid(True)
  18. plt.xlabel(x, fontsize=20)
  19. plt.ylabel(y, fontsize=20)
  20. def finalize_plot(shape=(1, 1)):
  21. plt.gcf().set_size_inches(
  22. shape[0] * 1.5 * plt.gcf().get_size_inches()[1],
  23. shape[1] * 1.5 * plt.gcf().get_size_inches()[1])
  24. plt.tight_layout()
  25. import time
  26.  
  27.  
  28. key = random.PRNGKey(0)
  29.  
  30.  
  31.  
  32.  
  33.  
  34. def brownian(key, batch_size=1):
  35. '''
  36. Arguments:
  37. key: sets up random number generation
  38. a: vector containing the temperatures every X timesteps (X depends on the length of a)
  39. batch_size: batch size
  40. Returns:
  41. order_param: energy after minimization
  42. '''
  43.  
  44. #Set up simulation parameters
  45. N = 100
  46. box_size = 4.9
  47. displacement, shift = space.periodic(box_size)
  48. dimension = 2
  49. dt = 1e-4
  50. num_steps = 1000
  51. T_schedule = lambda t: np.exp(-t)
  52.  
  53. #Set a random initial state
  54. key, split = random.split(key)
  55. R = random.uniform(split, (N, dimension), minval=0.0, maxval=box_size, dtype=np.float64)
  56.  
  57.  
  58. #Compile the dynamics
  59. energy_fn = energy.soft_sphere_pair(displacement, epsilon=5.0)
  60. init, apply = simulate.brownian(energy_fn, shift, dt=dt, T_schedule=T_schedule)
  61.  
  62. key, split = random.split(key)
  63. apply = jit(apply)
  64. state = init(split, R)
  65.  
  66.  
  67. #Run the brownian dynamics
  68. for i in range(num_steps):
  69. t = i * dt
  70. state = apply(state)
  71.  
  72. return energy_fn(state.position), state.position
  73.  
  74.  
  75.  
  76.  
  77. eng, R = brownian(key)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement