Guest User

merge.py

a guest
Sep 16th, 2022
29
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.51 KB | None | 0 0
import argparse
import os
import sys

import torch
from tqdm import tqdm

  6. parser = argparse.ArgumentParser(description="Merge two models")
  7. parser.add_argument("model_0", type=str, help="Path to model 0")
  8. parser.add_argument("model_1", type=str, help="Path to model 1")
  9. parser.add_argument("--alpha", type=float, help="Alpha value, optional, defaults to 0.5", default=0.5, required=False)
  10. parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False)
  11.  
  12. args = parser.parse_args()
  13.  
  14. model_0 = torch.load(args.model_0)
  15. model_1 = torch.load(args.model_1)
  16. theta_0 = model_0["state_dict"]
  17. theta_1 = model_1["state_dict"]
  18. alpha = args.alpha
  19.  
  20. output_file = f'{args.output}-{str(alpha)[2:] + "0"}.ckpt'
  21.  
  22. # check if output file already exists, ask to overwrite
  23. if os.path.isfile(output_file):
  24. print("Output file already exists. Overwrite? (y/n)")
  25. while True:
  26. overwrite = input()
  27. if overwrite == "y":
  28. break
  29. elif overwrite == "n":
  30. print("Exiting...")
  31. exit()
  32. else:
  33. print("Please enter y or n")
  34.  
  35.  
  36. for key in tqdm(theta_0.keys(), desc="Stage 1/2"):
  37. if "model" in key and key in theta_1:
  38. theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]
  39.  
  40. for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
  41. if "model" in key and key not in theta_0:
  42. theta_0[key] = theta_1[key]
  43.  
  44. print("Saving...")
  45.  
  46. torch.save({"state_dict": theta_0}, output_file)
  47.  
  48. print("Done!")
  49.  
Add Comment
Please, Sign In to add comment