Advertisement
fakeng

viz_checkpoint

Jul 18th, 2023
45
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.77 KB | None | 0 0
  1. import csv
  2. from safetensors.torch import load_file
  3. import torch
  4. from pathlib import Path
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from tqdm import tqdm
  8. import time
  9.  
  10. def cal_cross_attn(to_q, to_k, to_v, rand_input):
  11. hidden_dim, embed_dim = to_q.shape
  12. attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
  13. attn_to_k = nn.Linear(hidden_dim, embed_dim, bias=False)
  14. attn_to_v = nn.Linear(hidden_dim, embed_dim, bias=False)
  15. attn_to_q.load_state_dict({"weight": to_q})
  16. attn_to_k.load_state_dict({"weight": to_k})
  17. attn_to_v.load_state_dict({"weight": to_v})
  18.  
  19. return torch.einsum(
  20. "ik, jk -> ik",
  21. F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1),
  22. attn_to_v(rand_input)
  23. )
  24.  
  25. def model_hash(filename):
  26. try:
  27. with open(filename, "rb") as file:
  28. import hashlib
  29. m = hashlib.sha256()
  30.  
  31. file.seek(0x100000)
  32. m.update(file.read(0x10000))
  33. return m.hexdigest()[0:8]
  34. except FileNotFoundError:
  35. return 'NOFILE'
  36.  
  37. def load_model(path):
  38. if path.suffix == ".safetensors":
  39. return load_file(path, device="cpu")
  40. else:
  41. ckpt = torch.load(path, map_location="cpu")
  42. return ckpt["state_dict"] if "state_dict" in ckpt else ckpt
  43.  
  44. def eval(model, n, input):
  45. qk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"
  46. uk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_k.weight"
  47. vk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_v.weight"
  48. atoq, atok, atov = model[qk], model[uk], model[vk]
  49.  
  50. attn = cal_cross_attn(atoq, atok, atov, input)
  51. return attn
  52.  
  53. def compare_checkpoints(file1, file2, csv_writer):
  54. model_a = load_model(file1)
  55. model_b = load_model(file2)
  56.  
  57. map_attn_a = {}
  58. map_attn_b = {}
  59. map_rand_input = {}
  60.  
  61. for n in range(3, 11):
  62. hidden_dim, embed_dim = model_a[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"].shape
  63. rand_input = torch.randn([embed_dim, hidden_dim])
  64.  
  65. map_attn_a[n] = eval(model_a, n, rand_input)
  66. map_attn_b[n] = eval(model_b, n, rand_input)
  67. map_rand_input[n] = rand_input
  68.  
  69. sims = []
  70. for n in range(3, 11):
  71. attn_a = map_attn_a[n]
  72. attn_b = map_attn_b[n]
  73.  
  74. sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
  75. sims.append(sim)
  76.  
  77. similarity = torch.mean(torch.stack(sims)) * 100
  78.  
  79. csv_writer.writerow([file1.name, file2.name, similarity.item()])
  80.  
  81.  
  82. # List of file paths to compare
  83. file_paths = [
  84. "Basil mix.safetensors",
  85. "CLR+Izumi+BRAv5.safetensors",
  86. "CLRL_IzumiBarV5.safetensors",
  87. "ChillLofiRealistcv2.safetensors",
  88. "DreamLikeNovelInkF222VisionRealism.safetensors",
  89. "HenxxmixReal.safetensors",
  90. "V08_V08.safetensors",
  91. "anything-v4.0.ckpt",
  92. "asianRole_v10.safetensors",
  93. "babes_11.safetensors",
  94. "bra_v5.safetensors",
  95. "chilloutmix_NiPrunedFp32Fix.safetensors",
  96. "chineseDigitalArt_10.ckpt",
  97. "clarity_19.safetensors",
  98. "clarity_2.safetensors",
  99. "deliberate_v2.ckpt",
  100. "dreamlike-photoreal-2.0.ckpt",
  101. "dreamshaper_4BakedVae.safetensors",
  102. "dungeonsNWaifusNew_dungeonsNWaifus22.safetensors",
  103. "dvarchDreamlikePhotReal.safetensors",
  104. "dvarchDreamlikePhotRealAsianRoleAOM3A1B.safetensors",
  105. "f222_v1.ckpt",
  106. "fotoAssisted_v0.safetensors",
  107. "hassanblend1512And_hassanblend1512.ckpt",
  108. "icbinpICantBelieveIts_afterburn.safetensors",
  109. "izumi_01Safetensors.safetensors",
  110. "koreanstyle25D_koreanstyle25DBaked.safetensors",
  111. "lofi_V2pre.safetensors",
  112. "majicmixFantasy_v20.safetensors",
  113. "photon_v1.safetensors",
  114. "juggernaut_final.safetensors",
  115. "realisticVisionV20_v20.safetensors",
  116. "uberRealisticPornMerge_urpmv13.safetensors",
  117. "v1-5-pruned.ckpt",
  118. "v2-1_768-ema-pruned.safetensors",
  119. "xxmix9realistic_v26.safetensors"
  120. ]
  121.  
  122. # Create and open the CSV file
  123. with open('checkpoint_similarity.csv', mode='w', newline='') as file:
  124. writer = csv.writer(file)
  125. writer.writerow(['Source', 'Target', 'Weight'])
  126.  
  127. # Total number of comparisons
  128. total_comparisons = (len(file_paths) * (len(file_paths) - 1)) // 2
  129.  
  130. # Progress bar initialization
  131. progress_bar = tqdm(total=total_comparisons, ncols=80)
  132.  
  133. # Comparing each file with others
  134. for i, file1 in enumerate(file_paths):
  135. for file2 in file_paths[i + 1:]:
  136. compare_checkpoints(Path(file1), Path(file2), writer)
  137. time.sleep(0.01) # Simulating computation time
  138. progress_bar.update(1)
  139.  
  140. # Completion message
  141. print("It is done!")
  142.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement