import os

import torch
import safetensors.torch
from safetensors import safe_open
def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
    """Add zero-filled dummy adaLN_modulation LoRA weights when a state dict has
    final_layer.linear LoRA weights but no final_layer.adaLN_modulation_1 ones.

    Returns True if dummy weights were added, False otherwise.
    """
    adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
    adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
    linear_down_key = f"{prefix}_linear.lora_down.weight"
    linear_up_key = f"{prefix}_linear.lora_up.weight"

    if verbose:
        print(f"\nChecking for final_layer keys with prefix: '{prefix}'")

    final_layer_linear_down = state_dict.get(linear_down_key)
    final_layer_linear_up = state_dict.get(linear_up_key)

    has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
    has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None

    if verbose:
        print(f"  Has final_layer.linear: {has_linear}")
        print(f"  Has final_layer.adaLN_modulation_1: {has_adaLN}")

    if has_linear and not has_adaLN:
        # Zero tensors shaped like the existing linear LoRA weights, so the adaLN
        # keys are present but contribute nothing when the LoRA is applied.
        dummy_down = torch.zeros_like(final_layer_linear_down)
        dummy_up = torch.zeros_like(final_layer_linear_up)
        state_dict[adaLN_down_key] = dummy_down
        state_dict[adaLN_up_key] = dummy_up
        if verbose:
            print("Added dummy adaLN weights.")
        return True  # Was patched

    return False  # Nothing changed
def patch_file(input_path, output_path):
    """Load a .safetensors LoRA, patch the final_layer adaLN keys if needed,
    and save the result to output_path."""
    state_dict = {}
    with safe_open(input_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)

    patched = False
    # Try the key prefixes used by common LoRA naming conventions.
    prefixes = [
        "lora_unet_final_layer",
        "final_layer",
        "base_model.model.final_layer",
    ]
    for prefix in prefixes:
        before = len(state_dict)
        did_patch = patch_final_layer_adaLN(state_dict, prefix=prefix, verbose=False)
        after = len(state_dict)
        if did_patch and after > before:
            patched = True
            break

    if patched:
        safetensors.torch.save_file(state_dict, output_path)
        print(f"Patched and saved: {os.path.basename(output_path)}")
    else:
        print(f"Skipped (already has adaLN or missing final_layer): {os.path.basename(input_path)}")
def main():
    print("Batch LoRA adaLN Patcher")
    input_folder = input("Enter input folder path: ").strip()
    output_folder = input("Enter output folder path: ").strip()

    if not os.path.isdir(input_folder):
        print("Invalid input folder.")
        return
    os.makedirs(output_folder, exist_ok=True)

    files = [f for f in os.listdir(input_folder) if f.endswith(".safetensors")]
    print(f"\nFound {len(files)} .safetensors files in: {input_folder}")

    for filename in files:
        in_path = os.path.join(input_folder, filename)
        out_path = os.path.join(output_folder, filename)
        patch_file(in_path, out_path)

    print("\nDone. Patched files saved to:", output_folder)


if __name__ == "__main__":
    main()
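
# --- Example usage (illustrative sketch; filenames and paths below are hypothetical) ---
# The script is interactive; run it and answer the two folder prompts:
#   python patch_lora_adaln.py
# Or import patch_file() from another script to patch a single LoRA:
#   patch_file("loras/my_lora.safetensors", "loras_patched/my_lora.safetensors")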