Advertisement
Guest User

Script to merge sharded safetensors files into a single file

a guest
Feb 16th, 2025
1,183
1
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.17 KB | None | 1 0
  1. import json
  2. import safetensors
  3. import safetensors.torch
  4. import torch
  5. import os
  6.  
  7. def merge_safetensors_shards(index_file_path, output_file_path):
  8. """
  9. Merges safetensors shards into a single safetensors file.
  10.  
  11. Args:
  12. index_file_path (str): Path to the model.safetensors.index.json file.
  13. output_file_path (str): Path to save the merged safetensors file.
  14. """
  15.  
  16. try:
  17. with open(index_file_path, 'r') as f:
  18. index_data = json.load(f)
  19. except FileNotFoundError:
  20. print(f"Error: Index file not found at {index_file_path}")
  21. return
  22.  
  23. shard_file_directory = os.path.dirname(index_file_path)
  24. merged_weights = {}
  25. total_size = 0
  26.  
  27. print("Starting to merge safetensors shards...")
  28.  
  29. for weight_name, shard_file in index_data['weight_map'].items():
  30. shard_path = os.path.join(shard_file_directory, shard_file)
  31. try:
  32. with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
  33. tensor = f.get_tensor(weight_name)
  34. merged_weights[weight_name] = tensor
  35. total_size += tensor.numel() * tensor.element_size() # Calculate size in bytes
  36. print(f"Loaded weight '{weight_name}' from shard '{shard_file}'")
  37. except FileNotFoundError:
  38. print(f"Error: Shard file not found at {shard_path}")
  39. return
  40.  
  41. print(f"All shards loaded. Total size of merged weights: {total_size} bytes")
  42. print(f"Saving merged safetensors file to: {output_file_path}")
  43.  
  44. try:
  45. safetensors.torch.save_file(merged_weights, output_file_path, metadata={"format": "pt"})
  46. print(f"Successfully merged safetensors shards to {output_file_path}.")
  47. except Exception as e:
  48. print(f"Error saving merged safetensors file: {e}")
  49.  
  50.  
  51. if __name__ == "__main__":
  52. index_file_path = "diffusion_pytorch_model.safetensors.index.json" # index.json file from the repo
  53. output_file_path = "your_path/merged_diffusion_pytorch_model.safetensors" # You can change the output name if needed
  54.  
  55. merge_safetensors_shards(index_file_path, output_file_path)
  56.  
  57. print("\nMerge process completed.")
Tags: huggingface
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement