Advertisement
iSach

Untitled

Oct 25th, 2023
29
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.96 KB | None | 0 0
  1. import torch
  2. from ms_training import load_trajs # assuming the model is defined in ms_training.py
  3. import numpy as np
  4. import json
  5. from neuralop.models import FNO, TFNO
  6. import seaborn
  7. from neuralop.utils import count_params
  8. import os
  9. import cv2
  10.  
  11. MODELS = {
  12. "FNO": FNO,
  13. "TFNO": TFNO,
  14. }
  15.  
  16. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  17.  
  18. def vorticity2rgb(
  19. w,
  20. vmin = -1.25,
  21. vmax = 1.25,
  22. ):
  23. w = (w - vmin) / (vmax - vmin)
  24. w = 2 * w - 1
  25. w = np.sign(w) * np.abs(w) ** 0.8
  26. w = (w + 1) / 2
  27. w = seaborn.cm.icefire(w)
  28. w = 256 * w[..., :3]
  29. w = w.astype(np.uint8)
  30.  
  31. return w
  32.  
  33. def load_model(dir):
  34. with open(dir + "/model_config.json", "r") as f:
  35. model_config = json.load(f)
  36.  
  37. print(model_config)
  38.  
  39. model_config['model'] = MODELS[model_config['model']]
  40. model = model_config['model'](**model_config['params'])
  41. model.load_state_dict(torch.load(dir + "/model.pt", map_location=torch.device('cpu')))
  42. model.eval()
  43.  
  44. return model
  45.  
  46.  
  47. def sim_traj(model, gt_traj):
  48. """
  49. Traj: [T, H, W]
  50.  
  51. Output: pred_traj [T, H, W]
  52. """
  53.  
  54. pred_traj = torch.zeros_like(gt_traj).to(gt_traj.device)
  55. pred_traj[0] = gt_traj[0]
  56. with torch.no_grad():
  57. for i in range(1, gt_traj.shape[0]):
  58. w = pred_traj[i-1].unsqueeze(0).unsqueeze(0)
  59. pred_traj[i] = model(w)
  60.  
  61. return pred_traj
  62.  
  63.  
  64. def sim_and_render(folder_name, model, gt_traj):
  65. pred_traj = sim_traj(model, gt_traj)
  66.  
  67. os.makedirs(f"renders/{folder_name}", exist_ok=True)
  68.  
  69. def write_traj(traj, name, fps=10):
  70. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  71. out = cv2.VideoWriter(f'renders/{folder_name}/{name}.mp4', fourcc, 10, (traj.shape[2], traj.shape[1]))
  72. for frame in traj:
  73. if traj.shape[1] != traj.shape[2]:
  74. frame = cv2.putText(frame, 'G', (1, 8), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1, cv2.LINE_AA)
  75. frame = cv2.putText(frame, 'P', (traj.shape[2] // 2 + 1, 8), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1, cv2.LINE_AA)
  76. out.write(frame)
  77.  
  78. out.release()
  79.  
  80. pred_rgb, gt_rgb = vorticity2rgb(pred_traj), vorticity2rgb(gt_traj)
  81. cat_rgb = np.concatenate([gt_rgb, pred_rgb], axis=2)
  82. diff_traj = np.abs(gt_traj - pred_traj)
  83. diff_rgb = (seaborn.cm.rocket(diff_traj)[..., :3] * 256).astype(np.uint8)
  84.  
  85. write_traj(pred_rgb, "pred")
  86. write_traj(gt_rgb, "gt")
  87. write_traj(cat_rgb, "both")
  88. write_traj(diff_rgb, "diff")
  89.  
  90.  
  91. data_name = "navier_stokes"
  92. data = np.load("../datasets/data/ns_vorticity_1e-3.npy") # [50, 64, 64, B=500]
  93. data = data[:41, ..., 100:] # [41, 64, 64, 100]
  94. data = np.transpose(data, (3, 0, 1, 2)) # [100, 41, 64, 64]
  95. data = torch.tensor(data, dtype=torch.float32)
  96.  
  97. print(data.shape)
  98.  
  99. model = load_model("models/nmodes_(6, 6)")
  100.  
  101. print(count_params(model))
  102.  
  103. sim_and_render("6modes", model, data[0])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement