Guest User

DELLA merge method for Z-Image Turbo

a guest
Apr 30th, 2026
387
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.73 KB | Source Code | 0 0
  1. import torch
  2. from safetensors.torch import load_file, save_file
  3. import os
  4. import gc
  5.  
  6. def della_merge_streaming(base_path, variant_paths, weights, density=0.6, lam=1.1):
  7.     """
  8.    Hardware-optimized DELLA merge for 32GB RAM systems.
  9.    """
  10.     print(f"--- Starting DELLA Merge (Streaming Mode) ---")
  11.    
  12.     # 1. Load Base Model once (stays in RAM as the 'canvas')
  13.     print(f"Loading Base: {base_path}")
  14.     base_state = load_file(base_path, device="cpu")
  15.    
  16.     # Initialize a buffer for the merged deltas (Float32 for math precision)
  17.     # We create this to avoid modifying the base_state until the very end
  18.     merged_deltas = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in base_state.items()}
  19.  
  20.     for path, weight in zip(variant_paths, weights):
  21.         if not os.path.exists(path):
  22.             print(f"Warning: Model not found at {path}. Skipping...")
  23.             continue
  24.            
  25.         print(f"Processing: {os.path.basename(path)} with weight {weight}")
  26.         variant_state = load_file(path, device="cpu")
  27.        
  28.         for k in merged_deltas.keys():
  29.             if k not in variant_state: continue
  30.            
  31.             # DELLA Step 1: Calculate Delta (Variant - Base)
  32.             # Use .float() for precise magnitude ranking
  33.             delta = variant_state[k].to(torch.float32) - base_state[k].to(torch.float32)
  34.            
  35.             # DELLA Step 2: MagPrune
  36.             # Only keep the top 'density' % of the strongest weights
  37.             num_params = delta.numel()
  38.             top_k = int(density * num_params)
  39.            
  40.             if top_k > 0 and delta.dim() > 0:
  41.                 flat_delta = delta.abs().view(-1)
  42.                 # Find the value that separates top-k from the rest
  43.                 threshold = torch.kthvalue(flat_delta, num_params - top_k + 1).values
  44.                 mask = (delta.abs() >= threshold).float()
  45.                
  46.                 # DELLA Step 3: Rescale and Accumulate
  47.                 # Scaling by 1/density is the 'Rescale' part of DELLA
  48.                 rescale = 1.0 / density
  49.                 merged_deltas[k] += (delta * mask * weight * rescale * lam)
  50.        
  51.         # Free memory before loading the next model
  52.         del variant_state
  53.         gc.collect()
  54.  
  55.     print("Finalizing Merge...")
  56.     # 2. Add accumulated deltas back to the base model
  57.     for k in base_state.keys():
  58.         # Cast back to BF16 (standard for Z-Image) or FP16 for the final file
  59.         final_weight = base_state[k].to(torch.float32) + merged_deltas[k]
  60.         base_state[k] = final_weight.to(torch.bfloat16)
  61.         # Clear delta buffer to free RAM as we go
  62.         del merged_deltas[k]
  63.  
  64.     return base_state
  65.  
  66. # --- OPTIMIZED FOR TUNED KNOWLEDGE (Variant Dominant) ---
  67.  
  68. # 1. Path to your Z-Image Turbo Base (The anchor)
  69. BASE_MODEL = "zBastardv45.safetensors"
  70.  
  71. # 2. Paths to your specialized versions
  72. VARIANTS = [
  73.     "V5.5.safetensors",
  74.     "InbreedV5.safetensors",
  75.     "V6.safetensors"
  76. ]
  77.  
  78. # 3. Weights: Higher than 1.0 is fine here!
  79. # We want these to 'overwrite' the base model's default logic.
  80. # The first model in the list is usually your 'Primary' style.
  81. WEIGHTS = [1.0,0.5, 0.9]
  82.  
  83. # 4. DELLA Specific Hyperparameters
  84. # DENSITY: 0.75 means we keep 75% of the unique knowledge from the variants.
  85. # This prevents the base model from 'diluting' the fine-tuned details.
  86. DENSITY = 0.95
  87.  
  88. # LAMBDA: 1.15 acts as a signal booster for the 'expert' weights that survive.
  89. LAMBDA = 1.10
  90.  
  91. # --- EXECUTION ---
  92. merged_model = della_merge_streaming(BASE_MODEL, VARIANTS, WEIGHTS, density=DENSITY, lam=LAMBDA)
  93.  
  94. # Save the final file
  95. output_name = "BastarDELLA_V4.safetensors"
  96. save_file(merged_model, output_name)
  97. print(f"Success! Merged model saved as: {output_name}")
Tags: DELLA
Advertisement
Add Comment
Please, Sign In to add comment