Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#!/usr/bin/env python3
"""
Krea2 fal LoRA → ComfyUI conversion script
==========================================

Converts LoRA files trained with Krea2 (fal.ai format) into ComfyUI-compatible
safetensors files by remapping key prefixes. Only the key names change; the
tensor values are copied as-is.

The fal/PEFT format wraps all weights under:

    base_model.model.<layer>

ComfyUI expects them directly under the model prefix, e.g.:

    diffusion_model.<layer>

Keys that do not start with ``base_model.model.`` are copied through unchanged.

Usage
-----
    python convert_krea2_lora_to_comfy.py input.safetensors [output.safetensors]

If no output path is given, the script writes <input>_comfy.safetensors next
to the input file (e.g. ``my_lora.safetensors`` → ``my_lora_comfy.safetensors``).
"""
- import sys
- import os
- import argparse
- from pathlib import Path
- try:
- from safetensors.torch import load_file, save_file
- except ImportError:
- print("ERROR: safetensors is not installed. Run: pip install safetensors torch")
- sys.exit(1)
# ---------------------------------------------------------------------------
# Key-mapping rules
# ---------------------------------------------------------------------------
# Krea2 does not use the Flux architecture. The LoRA keys listed in the
# warning log that motivated this script (presumably the loader's
# "key not loaded" warnings — confirm) follow these patterns:
#
#   base_model.model.blocks.<N>.attn.{gate,wk,wo,wq,wv}.lora_{A,B}.weight
#   base_model.model.blocks.<N>.mlp.{down,gate,up}.lora_{A,B}.weight
#   base_model.model.first.lora_{A,B}.weight
#   base_model.model.last.linear.lora_{A,B}.weight
#   base_model.model.tmlp.{0,2}.lora_{A,B}.weight
#   base_model.model.tproj.1.lora_{A,B}.weight
#   base_model.model.txtfusion.layerwise_blocks.<N>.*
#   base_model.model.txtfusion.refiner_blocks.<N>.*
#   base_model.model.txtfusion.projector.lora_{A,B}.weight
#   base_model.model.txtmlp.{1,3}.lora_{A,B}.weight
#
# ComfyUI addresses diffusion-model weights as diffusion_model.* (the same
# convention it uses for Flux, SD3 and similar models).
#
# Strategy: replace the leading "base_model.model." of each key with
# "diffusion_model." and leave everything after it untouched.
# ---------------------------------------------------------------------------
# (prefix found in fal/PEFT keys, prefix ComfyUI expects instead)
SOURCE_PREFIX, TARGET_PREFIX = "base_model.model.", "diffusion_model."
def remap_key(key: str) -> str:
    """Translate one fal/PEFT LoRA key into the name ComfyUI expects.

    A key that begins with ``SOURCE_PREFIX`` has that prefix swapped for
    ``TARGET_PREFIX``. Any other key is returned exactly as given.
    """
    if not key.startswith(SOURCE_PREFIX):
        # Some other layout, or a global (non-layer) entry: leave it alone.
        return key
    layer_path = key[len(SOURCE_PREFIX):]
    return f"{TARGET_PREFIX}{layer_path}"
def convert(input_path: Path, output_path: Path) -> None:
    """Rename the keys of a fal/PEFT LoRA file for ComfyUI and save the result.

    Every key is passed through ``remap_key``. Tensors are stored unchanged
    under their new names. The input's safetensors header metadata (if any)
    is copied into the output, because ``load_file`` on its own discards it.

    Args:
        input_path: Source ``.safetensors`` LoRA file.
        output_path: Destination path. An existing file there is overwritten.

    Raises:
        ValueError: If two input keys map to the same output key (e.g. the
            file already contains both ``base_model.model.X`` and
            ``diffusion_model.X``), rather than silently dropping a tensor.
    """
    # Function-scope import: only this function needs safe_open (to read the header).
    from safetensors import safe_open

    print(f"Loading : {input_path}")
    state_dict = load_file(str(input_path))
    # load_file() returns tensors only, so read the optional metadata header separately.
    with safe_open(str(input_path), framework="pt") as handle:
        metadata = handle.metadata()
    print(f"Keys found: {len(state_dict)}")

    converted: dict = {}
    origin: dict = {}  # new key -> original key, used to report collisions
    remapped = 0
    unchanged = 0
    for key, tensor in state_dict.items():
        new_key = remap_key(key)
        if new_key in converted:
            # A plain dict assignment would overwrite the earlier tensor without notice.
            raise ValueError(
                f"Key collision: {origin[new_key]!r} and {key!r} both map to "
                f"{new_key!r}; refusing to drop a tensor silently."
            )
        origin[new_key] = key
        converted[new_key] = tensor
        if new_key != key:
            remapped += 1
        else:
            unchanged += 1

    print(f"Remapped : {remapped} keys ({SOURCE_PREFIX!r} → {TARGET_PREFIX!r})")
    if unchanged:
        print(f"Unchanged : {unchanged} keys (no matching prefix)")
    if not remapped:
        # Nothing matched the expected prefix, so the input is likely not a fal/PEFT LoRA.
        print(
            f"WARNING: no key starts with {SOURCE_PREFIX!r}; "
            "is this really a fal/PEFT LoRA?",
            file=sys.stderr,
        )

    # Show a few sample mappings so the user can sanity-check
    sample_keys = [k for k in state_dict if k.startswith(SOURCE_PREFIX)][:5]
    if sample_keys:
        print("\nSample key mappings:")
        for k in sample_keys:
            print(f" {k}")
            print(f" → {remap_key(k)}")

    print(f"\nSaving to : {output_path}")
    # metadata is None when the input has no header metadata; save_file accepts None.
    save_file(converted, str(output_path), metadata=metadata)
    if metadata:
        print(f"Metadata : {len(metadata)} entries preserved")
    print("Done ✓")
def main() -> None:
    """Command-line entry point: parse arguments, validate paths, run ``convert``.

    Exits with status 1 if the input is not an existing file or if the output
    path resolves to the input file itself.
    """
    parser = argparse.ArgumentParser(
        description="Convert Krea2 fal LoRA to ComfyUI format"
    )
    parser.add_argument("input", help="Path to the input .safetensors file")
    parser.add_argument(
        "output",
        nargs="?",
        default=None,
        help="Path for the output .safetensors file (default: <input>_comfy.safetensors)",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    # is_file() rather than exists(): a directory would otherwise pass this check
    # and only fail later inside load_file().
    if not input_path.is_file():
        print(f"ERROR: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)
    if input_path.suffix.lower() != ".safetensors":
        print("WARNING: Input file does not have a .safetensors extension.", file=sys.stderr)

    if args.output:
        output_path = Path(args.output)
    else:
        output_path = input_path.with_name(input_path.stem + "_comfy.safetensors")

    # Never overwrite the source LoRA: a wrong or failed conversion would destroy it.
    if output_path.resolve() == input_path.resolve():
        print(
            f"ERROR: Output path is the same as the input file: {output_path}",
            file=sys.stderr,
        )
        sys.exit(1)

    convert(input_path, output_path)


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment