EnnDee

Hugging Face model download & merge to .safetensors

Aug 26th, 2025
122
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.81 KB | Source Code | 0 0
  1. import os
  2. import json
  3. import shutil
  4. import sys
  5. import tkinter as tk
  6. from tkinter import simpledialog, messagebox
  7. from huggingface_hub import snapshot_download
  8. from safetensors import safe_open
  9. from safetensors.torch import save_file
  10. from urllib.parse import urlparse
  11.  
  12. def merge_safetensors_from_hf():
  13.     """
  14.    Opens a GUI window to enter a Hugging Face URL, downloads the model,
  15.    merges the .safetensors files, renames the file according to the model,
  16.    moves it to the script's location, and optionally cleans up the cache.
  17.    """
  18.     root = tk.Tk()
  19.     root.withdraw()
  20.     download_path = None  # Initialize in case of an early error
  21.  
  22.     try:
  23.         # 1. Get Hugging Face URL from the user via GUI
  24.         hf_url = simpledialog.askstring("Hugging Face Model Downloader",
  25.                                         "Please enter the Hugging Face URL of the model:",
  26.                                         parent=root)
  27.  
  28.         if not hf_url:
  29.             messagebox.showinfo("Cancelled", "Operation cancelled.")
  30.             return
  31.  
  32.         # 2. Parse the URL to extract the repository ID and model name
  33.         parsed_url = urlparse(hf_url)
  34.         path_parts = parsed_url.path.strip('/').split('/')
  35.  
  36.         if len(path_parts) >= 2 and parsed_url.netloc == "huggingface.co":
  37.             repo_id = f"{path_parts[0]}/{path_parts[1]}"
  38.             model_name = path_parts[1]  # The part after the username
  39.         else:
  40.             messagebox.showerror("Error", "Invalid Hugging Face URL.\nThe format should be 'https://huggingface.co/user/model'.")
  41.             return
  42.  
  43.         # 3. Download the model files
  44.         messagebox.showinfo("Download Starting", f"Repository: {repo_id}\n\nThe download will now begin. This may take some time. The program might not respond during this process.")
  45.         download_path = snapshot_download(repo_id=repo_id)
  46.  
  47.         # 4. Find and load the index file
  48.         index_path = os.path.join(download_path, "model.safetensors.index.json")
  49.         if not os.path.exists(index_path):
  50.             messagebox.showerror("Error", "'model.safetensors.index.json' not found.\n\nThis script is only suitable for models with split .safetensors files.")
  51.             return
  52.  
  53.         with open(index_path, "r") as f:
  54.             index = json.load(f)
  55.  
  56.         # 5. Merge tensors from all files
  57.         merged_tensors = {}
  58.         safetensor_files = set(index["weight_map"].values())
  59.         for safetensor_file in safetensor_files:
  60.             file_path = os.path.join(download_path, safetensor_file)
  61.             with safe_open(file_path, framework="pt", device="cpu") as f:
  62.                 for key in f.keys():
  63.                     merged_tensors[key] = f.get_tensor(key)
  64.  
  65.         # 6. Temporarily save the merged tensors in the download folder
  66.         temp_output_path = os.path.join(download_path, "merged_model.safetensors")
  67.         save_file(merged_tensors, temp_output_path)
  68.  
  69.         # 7. Move the file to the script's location and rename it
  70.         script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
  71.         final_filename = f"{model_name}.safetensors"
  72.         final_destination_path = os.path.join(script_dir, final_filename)
  73.  
  74.         shutil.move(temp_output_path, final_destination_path)
  75.  
  76.         # 8. Success message for the move and preparation for cleanup
  77.         messagebox.showinfo("Model Merged", f"The model was successfully created and saved here:\n\n{final_destination_path}")
  78.  
  79.         # 9. Ask whether the temporary data should be deleted
  80.         cleanup_decision = messagebox.askyesno("Cleanup", "Would you like to delete the temporary download data from the cache to free up space?\n\n(This will remove the original download files, not your new merged file.)")
  81.         if cleanup_decision:
  82.             shutil.rmtree(download_path)
  83.             messagebox.showinfo("Cleaned Up", f"The temporary folder has been deleted:\n\n{download_path}")
  84.  
  85.         messagebox.showinfo("Done!", "The process has been successfully completed.")
  86.  
  87.     except Exception as e:
  88.         messagebox.showerror("An error occurred", f"An unexpected error occurred:\n\n{str(e)}")
  89.         # Optional cleanup on error if the download path exists
  90.         if download_path and os.path.exists(download_path):
  91.             cleanup_decision = messagebox.askyesno("Error: Cleanup", "An error has occurred. Would you still like to delete the incomplete download data?")
  92.             if cleanup_decision:
  93.                 shutil.rmtree(download_path)
  94.                 messagebox.showinfo("Cleaned Up", "Incomplete download data has been deleted.")
  95.  
  96.     finally:
  97.         root.destroy()
  98.  
  99. if __name__ == "__main__":
  100.     # Redirect error messages that are written directly to the console
  101.     # to prevent the window from "flashing" on an error.
  102.     sys.stderr = open(os.devnull, 'w')
  103.     merge_safetensors_from_hf()
Advertisement
Add Comment
Please, Sign In to add comment