Advertisement
Guest User

prune.py

a guest
Sep 16th, 2022
21
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.70 KB | None | 0 0
  1. import os
  2. import torch
  3.  
  4.  
  5. def prune_it(p, keep_only_ema=False):
  6. print(f"prunin' in path: {p}")
  7. size_initial = os.path.getsize(p)
  8. nsd = dict()
  9. sd = torch.load(p, map_location="cpu")
  10. print(sd.keys())
  11. for k in sd.keys():
  12. if k != "optimizer_states":
  13. nsd[k] = sd[k]
  14. else:
  15. print(f"removing optimizer states for path {p}")
  16. if "global_step" in sd:
  17. print(f"This is global step {sd['global_step']}.")
  18. if keep_only_ema:
  19. sd = nsd["state_dict"].copy()
  20. # infer ema keys
  21. ema_keys = {k: "model_ema." + k[6:].replace(".", ".") for k in sd.keys() if k.startswith("model.")}
  22. new_sd = dict()
  23.  
  24. for k in sd:
  25. if k in ema_keys:
  26. new_sd[k] = sd[ema_keys[k]].half()
  27. elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
  28. new_sd[k] = sd[k].half()
  29.  
  30. assert len(new_sd) == len(sd) - len(ema_keys)
  31. nsd["state_dict"] = new_sd
  32. else:
  33. sd = nsd['state_dict'].copy()
  34. new_sd = dict()
  35. for k in sd:
  36. new_sd[k] = sd[k].half()
  37. nsd['state_dict'] = new_sd
  38.  
  39. fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
  40. print(f"saving pruned checkpoint at: {fn}")
  41. torch.save(nsd, fn)
  42. newsize = os.path.getsize(fn)
  43. MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
  44. f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states"
  45. if keep_only_ema:
  46. MSG += " and non-EMA weights"
  47. print(MSG)
  48.  
  49.  
  50. if __name__ == "__main__":
  51. prune_it('YOUR-MODEL-HERE.ckpt')
  52.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement