import os
import torch
import safetensors.torch
from safetensors import safe_open


def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
    """Add zeroed adaLN_modulation_1 LoRA weights when only final_layer.linear is present."""
    final_layer_linear_down = None
    final_layer_linear_up = None

    adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
    adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
    linear_down_key = f"{prefix}_linear.lora_down.weight"
    linear_up_key = f"{prefix}_linear.lora_up.weight"

    if verbose:
        print(f"\nšŸ” Checking for final_layer keys with prefix: '{prefix}'")

    if linear_down_key in state_dict:
        final_layer_linear_down = state_dict[linear_down_key]
    if linear_up_key in state_dict:
        final_layer_linear_up = state_dict[linear_up_key]

    has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
    has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None

    if verbose:
        print(f" āœ… Has final_layer.linear: {has_linear}")
        print(f" āœ… Has final_layer.adaLN_modulation_1: {has_adaLN}")

    if has_linear and not has_adaLN:
        # Zero tensors shaped like the existing linear LoRA weights act as a no-op adaLN patch.
        dummy_down = torch.zeros_like(final_layer_linear_down)
        dummy_up = torch.zeros_like(final_layer_linear_up)
        state_dict[adaLN_down_key] = dummy_down
        state_dict[adaLN_up_key] = dummy_up

        if verbose:
            print("āœ… Added dummy adaLN weights.")
        return True  # Was patched
    return False  # Nothing changed


def patch_file(input_path, output_path):
    # Load every tensor from the input .safetensors file into a plain dict.
    state_dict = {}
    with safe_open(input_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)

    patched = False
    # Key prefixes vary by LoRA trainer/format, so try the common ones in turn.
    prefixes = [
        "lora_unet_final_layer",
        "final_layer",
        "base_model.model.final_layer",
    ]

    for prefix in prefixes:
        before = len(state_dict)
        did_patch = patch_final_layer_adaLN(state_dict, prefix=prefix, verbose=False)
        after = len(state_dict)
        if did_patch and after > before:
            patched = True
            break

    if patched:
        safetensors.torch.save_file(state_dict, output_path)
        print(f"āœ… Patched and saved: {os.path.basename(output_path)}")
    else:
        print(f"āš ļø Skipped (already has adaLN or missing final_layer): {os.path.basename(input_path)}")


def main():
    print("šŸ”„ Batch LoRA adaLN Patcher")
    input_folder = input("Enter input folder path: ").strip()
    output_folder = input("Enter output folder path: ").strip()

    if not os.path.isdir(input_folder):
        print("āŒ Invalid input folder.")
        return
    os.makedirs(output_folder, exist_ok=True)

    files = [f for f in os.listdir(input_folder) if f.endswith(".safetensors")]
    print(f"\nšŸ“‚ Found {len(files)} .safetensors files in: {input_folder}")

    for filename in files:
        in_path = os.path.join(input_folder, filename)
        out_path = os.path.join(output_folder, filename)
        patch_file(in_path, out_path)

    print("\nāœ… Done. Patched files saved to:", output_folder)


if __name__ == "__main__":
    main()
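
After running the script, a patched file can be sanity-checked by listing its keys and confirming the dummy adaLN tensors were added. A minimal sketch (the file path below is a placeholder, not part of the script):

from safetensors import safe_open

# Hypothetical path to one of the files written into the output folder above.
with safe_open("output_folder/my_lora.safetensors", framework="pt", device="cpu") as f:
    adaLN_keys = [k for k in f.keys() if "adaLN_modulation_1" in k]
    print("adaLN keys found:", adaLN_keys)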