BuccoBruce

Mirror of the final version of https://github.com/harubaru/waifu-diffusion/tree/main/scripts/prune.py

Jan 3rd, 2023
1,058
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.91 KB | Source Code | 0 0
  1. import os
  2. from pathlib import Path
  3. import torch
  4. import argparse
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument('--input', '-I', type=str, help='Input file to prune', required = True)
  7. args = parser.parse_args()
  8. file = args.input
  9.  
  10.  
  11. def prune_it(p, keep_only_ema=True):
  12.     print(f"prunin' in path: {p}")
  13.     size_initial = os.path.getsize(p)
  14.     nsd = dict()
  15.     sd = torch.load(p, map_location="cpu")
  16.     print(sd.keys())
  17.     for k in sd.keys():
  18.         if k != "optimizer_states":
  19.             nsd[k] = sd[k]
  20.     else:
  21.         print(f"removing optimizer states for path {p}")
  22.     if "global_step" in sd:
  23.         print(f"This is global step {sd['global_step']}.")
  24.     if keep_only_ema:
  25.         sd = nsd["state_dict"].copy()
  26.         # infer ema keys
  27.         ema_keys = {k: "model_ema." + k[6:].replace(".", "") for k in sd.keys() if k.startswith('model.')}
  28.         new_sd = dict()
  29.  
  30.         for k in sd:
  31.             if k in ema_keys:
  32.                 print(k, ema_keys[k])
  33.                 new_sd[k] = sd[ema_keys[k]]
  34.             elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
  35.                 new_sd[k] = sd[k]
  36.  
  37.         assert len(new_sd) == len(sd) - len(ema_keys)
  38.         nsd["state_dict"] = new_sd
  39.     else:
  40.         sd = nsd['state_dict'].copy()
  41.         new_sd = dict()
  42.         for k in sd:
  43.             new_sd[k] = sd[k]
  44.         nsd['state_dict'] = new_sd
  45.  
  46.     fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
  47.     print(f"saving pruned checkpoint at: {fn}")
  48.     torch.save(nsd, fn)
  49.     newsize = os.path.getsize(fn)
  50.     MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
  51.           f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states"
  52.     if keep_only_ema:
  53.         MSG += " and non-EMA weights"
  54.     print(MSG)
  55.  
  56.  
  57. if __name__ == "__main__":
  58.     prune_it(file)
Advertisement
Add Comment
Please, Sign In to add comment