Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- from safetensors.torch import load_file, save_file
- import os
- import gc
def _keep_top_magnitudes(delta, density):
    """Zero (in place) every entry of *delta* below the top-``density`` magnitude cut.

    Returns the pruned tensor, or ``None`` when no entry would survive
    (``int(density * numel) == 0``) or the tensor is 0-dimensional; callers
    treat ``None`` as "this tensor contributes nothing".
    """
    num_params = delta.numel()
    top_k = int(density * num_params)
    if top_k <= 0 or delta.dim() == 0:
        return None

    magnitudes = delta.abs()
    # kthvalue ranks from the smallest value, so index (n - top_k + 1) is the
    # weakest magnitude that still belongs to the top_k survivors.
    threshold = torch.kthvalue(magnitudes.reshape(-1), num_params - top_k + 1).values
    # Entries tied with the threshold are kept, so slightly more than top_k
    # values may survive.
    return delta.masked_fill_(magnitudes < threshold, 0.0)


def della_merge_streaming(base_path, variant_paths, weights, density=0.6, lam=1.1):
    """Merge fine-tuned checkpoints into a base checkpoint, one variant at a time.

    For every floating-point tensor shared with the base, each variant
    contributes ``(variant - base)`` with all but the largest-magnitude
    ``density`` fraction of entries zeroed, multiplied by
    ``weight * (1 / density) * lam``. The contributions are summed in float32,
    added to the base tensor, and the result is stored as bfloat16.

    Memory: the base state dict, a float32 accumulator for every floating-point
    base tensor, and one variant state dict are resident at the same time.

    NOTE(review): pruning here is a deterministic top-k magnitude cut with no
    sign election between variants; confirm this is the DELLA variant intended.

    Args:
        base_path: Path to the base ``.safetensors`` file.
        variant_paths: Paths to the variant ``.safetensors`` files. Paths that
            do not exist are reported and skipped.
        weights: One multiplier per entry of ``variant_paths``, same order.
        density: Fraction of each delta tensor to keep, in ``(0, 1]``.
        lam: Extra multiplier applied to every surviving delta.

    Returns:
        dict mapping tensor names to merged tensors. Floating-point tensors are
        bfloat16; non-floating-point tensors are the base tensors, unchanged.

    Raises:
        ValueError: If ``variant_paths`` and ``weights`` differ in length, or
            ``density`` is outside ``(0, 1]``.
    """
    variant_paths = list(variant_paths)
    weights = list(weights)
    # zip() would silently drop the unmatched tail, merging fewer models than asked.
    if len(variant_paths) != len(weights):
        raise ValueError(
            f"Got {len(variant_paths)} variant paths but {len(weights)} weights; "
            "they must match one-to-one."
        )
    # density == 0 divides by zero in the rescale; density > 1 asks kthvalue
    # for a non-positive rank.
    if not 0.0 < density <= 1.0:
        raise ValueError(f"density must be in (0, 1], got {density}")

    print("--- Starting DELLA Merge (Streaming Mode) ---")

    # 1. Load the base model once; it is updated in place at the very end.
    print(f"Loading Base: {base_path}")
    base_state = load_file(base_path, device="cpu")

    # Float32 accumulators keep the summed deltas precise. Non-floating-point
    # tensors (e.g. integer buffers) get no accumulator and pass through as-is.
    merged_deltas = {
        name: torch.zeros_like(tensor, dtype=torch.float32)
        for name, tensor in base_state.items()
        if tensor.is_floating_point()
    }
    rescale = 1.0 / density  # compensates for the entries removed by pruning

    for path, weight in zip(variant_paths, weights):
        if not os.path.exists(path):
            print(f"Warning: Model not found at {path}. Skipping...")
            continue

        print(f"Processing: {os.path.basename(path)} with weight {weight}")
        variant_state = load_file(path, device="cpu")
        scale = weight * rescale * lam

        for name, accumulator in merged_deltas.items():
            variant_tensor = variant_state.get(name)
            if variant_tensor is None:
                continue

            base_tensor = base_state[name]
            if variant_tensor.shape != base_tensor.shape:
                # A mismatched tensor cannot be diffed; skip it instead of
                # aborting a merge that may already be hours in.
                print(
                    f"Warning: shape mismatch for {name} "
                    f"({tuple(variant_tensor.shape)} vs {tuple(base_tensor.shape)}). Skipping..."
                )
                continue

            # Step 1: delta in float32 so magnitude ranking is precise.
            delta = variant_tensor.to(torch.float32) - base_tensor.to(torch.float32)
            # Step 2: magnitude prune (in place on the temporary delta).
            pruned = _keep_top_magnitudes(delta, density)
            if pruned is None:
                continue
            # Step 3: rescale and accumulate.
            accumulator.add_(pruned.mul_(scale))

        # Free the variant before the next one is loaded.
        del variant_state
        gc.collect()

    print("Finalizing Merge...")
    # 2. Add accumulated deltas back onto the base tensors.
    for name in list(base_state):
        total_delta = merged_deltas.pop(name, None)  # pop frees the buffer as we go
        if total_delta is None:
            continue  # non-floating-point tensor: keep the base value untouched
        merged = total_delta.add_(base_state[name].to(torch.float32))
        base_state[name] = merged.to(torch.bfloat16)

    return base_state
# --- MERGE RECIPE (variant dominant) ---

# Anchor checkpoint: every variant is diffed against this file.
BASE_MODEL = "zBastardv45.safetensors"

# Fine-tuned checkpoints whose differences from the base get merged in.
VARIANTS = [
    "V5.5.safetensors",
    "InbreedV5.safetensors",
    "V6.safetensors"
]

# One multiplier per entry of VARIANTS, in the same order.
WEIGHTS = [1.0, 0.5, 0.9]

# Fraction of each delta tensor kept by magnitude pruning (0.95 keeps 95%).
DENSITY = 0.95

# Extra multiplier applied to the deltas that survive pruning.
LAMBDA = 1.10

# --- EXECUTION ---
merged_model = della_merge_streaming(
    base_path=BASE_MODEL,
    variant_paths=VARIANTS,
    weights=WEIGHTS,
    density=DENSITY,
    lam=LAMBDA,
)

# Write the merged state dict next to the script.
output_name = "BastarDELLA_V4.safetensors"
save_file(merged_model, output_name)
print(f"Success! Merged model saved as: {output_name}")
Advertisement
Add Comment
Please, Sign In to add comment