Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import json
- import shutil
- import sys
- import tkinter as tk
- from tkinter import simpledialog, messagebox
- from huggingface_hub import snapshot_download
- from safetensors import safe_open
- from safetensors.torch import save_file
- from urllib.parse import urlparse
def _parse_hf_model_url(hf_url):
    """Extract ``(repo_id, model_name)`` from a Hugging Face model URL.

    Accepts URLs with or without a scheme and with an optional ``www.``
    prefix, and ignores surrounding whitespace. Returns ``None`` when the
    input is not a ``huggingface.co/<user>/<model>`` URL.
    """
    text = hf_url.strip()
    if "://" not in text:
        # Without a scheme urlparse() puts the host into .path, not .netloc.
        text = "https://" + text
    parsed = urlparse(text)
    if (parsed.hostname or "") not in ("huggingface.co", "www.huggingface.co"):
        return None
    parts = [part for part in parsed.path.split("/") if part]
    # Dataset and Space URLs have the form /datasets/<user>/<name>; taking
    # their first two segments would produce a bogus model repo id.
    if len(parts) < 2 or parts[0] in ("datasets", "spaces"):
        return None
    return f"{parts[0]}/{parts[1]}", parts[1]


def _resolve_cache_cleanup_dir(download_path):
    """Return the directory whose removal actually frees the downloaded data.

    ``snapshot_download`` returns ``<cache>/models--user--name/snapshots/<rev>``.
    That folder normally holds only symlinks into the sibling ``blobs``
    folder, so deleting it alone frees almost nothing. When the path matches
    this layout the enclosing ``models--*`` folder is returned; otherwise the
    path is returned unchanged.
    """
    snapshots_dir = os.path.dirname(os.path.normpath(download_path))
    repo_cache_dir = os.path.dirname(snapshots_dir)
    if (os.path.basename(snapshots_dir) == "snapshots"
            and os.path.basename(repo_cache_dir).startswith("models--")):
        return repo_cache_dir
    return download_path


def merge_safetensors_from_hf():
    """Download a sharded Hugging Face model and merge it into one file.

    Asks for a model URL in a dialog, downloads the repository's safetensors
    shards and index, loads every tensor on the CPU, and writes them into a
    single ``<model_name>.safetensors`` file next to the running script.
    Afterwards the user may choose to delete the downloaded cache data.

    All tensors are held in memory at the same time, so the machine needs at
    least as much free RAM as the model is large.

    Returns:
        None. Progress and errors are reported through message boxes.
    """
    root = tk.Tk()
    root.withdraw()  # only dialogs are shown, never the empty main window
    download_path = None  # set once the download succeeded
    temp_output_path = None  # set while a partially written file may exist
    try:
        # 1. Get the Hugging Face URL from the user via GUI
        hf_url = simpledialog.askstring(
            "Hugging Face Model Downloader",
            "Please enter the Hugging Face URL of the model:",
            parent=root,
        )
        if not hf_url or not hf_url.strip():
            messagebox.showinfo("Cancelled", "Operation cancelled.")
            return

        # 2. Parse the URL to extract the repository ID and model name
        parsed = _parse_hf_model_url(hf_url)
        if parsed is None:
            messagebox.showerror("Error", "Invalid Hugging Face URL.\nThe format should be 'https://huggingface.co/user/model'.")
            return
        repo_id, model_name = parsed

        # 3. Download only the files the merge needs (shards + index)
        messagebox.showinfo("Download Starting", f"Repository: {repo_id}\n\nThe download will now begin. This may take some time. The program might not respond during this process.")
        download_path = snapshot_download(
            repo_id=repo_id,
            allow_patterns=["*.safetensors", "model.safetensors.index.json"],
        )

        # 4. Find and load the index file
        index_path = os.path.join(download_path, "model.safetensors.index.json")
        if not os.path.exists(index_path):
            messagebox.showerror("Error", "'model.safetensors.index.json' not found.\n\nThis script is only suitable for models with split .safetensors files.")
            return
        with open(index_path, "r", encoding="utf-8") as index_file:
            index = json.load(index_file)

        # 5. Merge tensors from all shards (sorted for a reproducible order)
        merged_tensors = {}
        for shard_name in sorted(set(index["weight_map"].values())):
            shard_path = os.path.join(download_path, shard_name)
            with safe_open(shard_path, framework="pt", device="cpu") as shard:
                for key in shard.keys():
                    merged_tensors[key] = shard.get_tensor(key)

        # 6. Write the merged file next to the script. It is written under a
        #    temporary name first so that a failed save never leaves a
        #    truncated file under the final name, and so the cache folder is
        #    not polluted with a multi-gigabyte copy.
        script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        final_destination_path = os.path.join(script_dir, f"{model_name}.safetensors")
        temp_output_path = final_destination_path + ".part"
        # The "format" tag is what the original shards carry; some loaders
        # refuse files without it.
        save_file(merged_tensors, temp_output_path, metadata={"format": "pt"})
        os.replace(temp_output_path, final_destination_path)
        temp_output_path = None

        # 7. Success message for the merge and preparation for cleanup
        messagebox.showinfo("Model Merged", f"The model was successfully created and saved here:\n\n{final_destination_path}")

        # 8. Ask whether the temporary data should be deleted
        if messagebox.askyesno("Cleanup", "Would you like to delete the temporary download data from the cache to free up space?\n\n(This will remove the original download files, not your new merged file.)"):
            cleanup_dir = _resolve_cache_cleanup_dir(download_path)
            shutil.rmtree(cleanup_dir)
            messagebox.showinfo("Cleaned Up", f"The temporary folder has been deleted:\n\n{cleanup_dir}")
        messagebox.showinfo("Done!", "The process has been successfully completed.")
    except Exception as e:
        messagebox.showerror("An error occurred", f"An unexpected error occurred:\n\n{str(e)}")
        # Remove a half-written output file, if any
        if temp_output_path and os.path.exists(temp_output_path):
            try:
                os.remove(temp_output_path)
            except OSError:
                pass  # best effort; the main error was already reported
        # Optional cleanup on error if the download path exists
        if download_path and os.path.exists(download_path):
            if messagebox.askyesno("Error: Cleanup", "An error has occurred. Would you still like to delete the incomplete download data?"):
                try:
                    shutil.rmtree(_resolve_cache_cleanup_dir(download_path))
                except OSError as cleanup_error:
                    # Do not let a failed cleanup escape the error handler.
                    messagebox.showerror("An error occurred", f"An unexpected error occurred:\n\n{str(cleanup_error)}")
                else:
                    messagebox.showinfo("Cleaned Up", "Incomplete download data has been deleted.")
    finally:
        root.destroy()
if __name__ == "__main__":
    # Silence messages that libraries write directly to the console, to
    # prevent a console window from "flashing" on an error. The null device
    # is closed and the original stream restored afterwards so that the file
    # handle is not leaked and interpreter shutdown sees a valid stderr.
    original_stderr = sys.stderr
    with open(os.devnull, "w") as devnull:
        sys.stderr = devnull
        try:
            merge_safetensors_from_hf()
        finally:
            sys.stderr = original_stderr
Advertisement
Add Comment
Please, Sign In to add comment