Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # merge_bases.py
- import argparse
- import torch
- from safetensors.torch import load_file, save_file
- import os
def _blend_weights(count, alpha):
    """Return one weight per model: ``alpha`` for the first, the rest split ``1 - alpha`` evenly."""
    share = (1.0 - alpha) / (count - 1)
    return [alpha] + [share] * (count - 1)


def _blend_tensors(key, tensors, weights):
    """Return the weighted sum of same-shaped tensors stored under ``key``.

    Floating-point tensors are accumulated in at least float32 (so bf16/fp16
    inputs do not lose precision in the running sum) and cast back to the dtype
    that plain tensor arithmetic would have produced (e.g. bf16 + bf16 -> bf16).
    Non-floating tensors (integer/bool buffers) cannot be meaningfully averaged,
    so the first model's tensor is returned unchanged.

    Raises:
        ValueError: if the tensors do not all have the same shape.
    """
    first = tensors[0]
    for other in tensors[1:]:
        if other.shape != first.shape:
            raise ValueError(
                f"Shape mismatch for '{key}': "
                f"{tuple(first.shape)} vs {tuple(other.shape)}."
            )

    if not all(t.is_floating_point() for t in tensors):
        return first

    out_dtype = first.dtype
    for other in tensors[1:]:
        out_dtype = torch.promote_types(out_dtype, other.dtype)
    acc_dtype = torch.promote_types(out_dtype, torch.float32)

    acc = torch.zeros(first.shape, dtype=acc_dtype, device=first.device)
    for weight, tensor in zip(weights, tensors):
        acc.add_(tensor.to(acc_dtype), alpha=weight)
    return acc.to(out_dtype)


def merge_bases(base_paths, alpha, output_path):
    """Blend two or more safetensors checkpoints into a single weighted average.

    The first checkpoint gets weight ``alpha``. The remaining ``n - 1``
    checkpoints share ``1 - alpha`` equally, so the weights always sum to 1.
    For two models this is the usual ``alpha * A + (1 - alpha) * B``.

    Args:
        base_paths: Paths to ``.safetensors`` files. All files must contain the
            same tensor names with the same shapes.
        alpha: Weight of the first model (nominally 0.0 - 1.0; values outside
            that range are not rejected and extrapolate).
        output_path: Destination ``.safetensors`` path for the merged weights.

    Raises:
        ValueError: if fewer than two paths are given, the files' key sets
            differ, or a tensor's shape differs between files.
    """
    paths = list(base_paths)
    if len(paths) < 2:
        raise ValueError("Need at least two base models to merge.")
    print(f"🔄 Merging {len(paths)} base models with alpha={alpha}")

    state_dicts = [load_file(path) for path in paths]
    keys = state_dicts[0].keys()
    for path, sd in zip(paths, state_dicts):
        if sd.keys() != keys:
            raise ValueError(
                f"Base models do not have matching keys: "
                f"{path} differs from {paths[0]}."
            )

    weights = _blend_weights(len(state_dicts), alpha)

    merged = {}
    # Snapshot the key list: popping below mutates the live dict_keys view.
    for k in list(keys):
        # Pop so each source tensor can be freed once its blend is stored,
        # instead of holding every input model plus the output at once.
        tensors = [sd.pop(k) for sd in state_dicts]
        merged[k] = _blend_tensors(k, tensors, weights)

    save_file(merged, output_path)
    print(f"✅ Merged base model saved to {output_path}")
if __name__ == "__main__":
    # Command-line entry point: parse the CLI flags and run the merge.
    #
    # Example:
    #   python merge_bases.py --bases model_a.safetensors model_b.safetensors --alpha 0.5 --output merged_50-50.safetensors
    cli = argparse.ArgumentParser(description="Merge two or more base model files.")
    cli.add_argument(
        '--bases',
        nargs='+',
        required=True,
        help='Paths to base model files (.safetensors)',
    )
    cli.add_argument(
        '--alpha',
        type=float,
        required=True,
        help='Blend weight (0.0 - 1.0) for the first model',
    )
    cli.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output path for the merged base',
    )
    opts = cli.parse_args()

    merge_bases(base_paths=opts.bases, alpha=opts.alpha, output_path=opts.output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement