Guest User

Hugging Face Download Script

a guest
Jan 28th, 2025
1,362
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.07 KB | Software | 0 0
  1. #!/usr/bin/env python3
  2. import os
  3. import sys
  4. import subprocess
  5. import requests
  6. import hashlib
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pathlib import Path
  9. from typing import List, Optional, Set, Dict
  10. import shutil
  11. import json
  12.  
  13. class HFDownloader:
  14.     def __init__(self, username: str, output_dir: Optional[str] = None, force: bool = False):
  15.         self.username = username
  16.         self.output_dir = output_dir or Path(f"hf_downloads/{username}")
  17.         self.api_url = "https://huggingface.co/api/models"
  18.         self.force = force
  19.        
  20.     def get_user_models(self) -> List[str]:
  21.         """Fetch all model IDs for a given username."""
  22.         params = {"author": self.username}
  23.         response = requests.get(self.api_url, params=params)
  24.         response.raise_for_status()
  25.         return [model["id"] for model in response.json()]
  26.  
  27.     def get_remote_files(self, model_id: str) -> Dict[str, str]:
  28.         """Get list of files and their SHA values from remote repository."""
  29.         try:
  30.             url = f"https://huggingface.co/api/models/{model_id}/tree/main"
  31.             response = requests.get(url)
  32.             response.raise_for_status()
  33.            
  34.             files = {}
  35.             for item in response.json():
  36.                 if 'path' in item and 'lfs' in item:
  37.                     files[item['path']] = item['lfs'].get('sha256') if item['lfs'] else None
  38.                 elif 'path' in item and 'oid' in item:
  39.                     files[item['path']] = item['oid']
  40.             return files
  41.         except Exception as e:
  42.             print(f"Error fetching remote files for {model_id}: {str(e)}")
  43.             return {}
  44.  
  45.     def get_local_files(self, model_dir: Path) -> Set[str]:
  46.         """Get list of all files in local repository."""
  47.         if not model_dir.exists():
  48.             return set()
  49.        
  50.         files = set()
  51.         for root, _, filenames in os.walk(model_dir):
  52.             for filename in filenames:
  53.                 rel_path = os.path.relpath(os.path.join(root, filename), model_dir)
  54.                 files.add(rel_path)
  55.         return files
  56.  
  57.     def sync_repository(self, model_id: str) -> None:
  58.         """Synchronize local repository with remote, ensuring exact match."""
  59.         try:
  60.             model_dir = Path(self.output_dir) / model_id.split("/")[-1]
  61.            
  62.             # Get remote files and their hashes
  63.             print(f"Checking remote structure for {model_id}...")
  64.             remote_files = self.get_remote_files(model_id)
  65.            
  66.             if not remote_files:
  67.                 print(f"Error: Could not fetch remote file list for {model_id}")
  68.                 return
  69.                
  70.             # Get local files
  71.             local_files = self.get_local_files(model_dir)
  72.            
  73.             # Find extra files to remove
  74.             extra_files = local_files - set(remote_files.keys())
  75.             if extra_files:
  76.                 print(f"Removing {len(extra_files)} extra files from {model_id}")
  77.                 for file in extra_files:
  78.                     (model_dir / file).unlink()
  79.            
  80.             # Download/update repository
  81.             print(f"Synchronizing {model_id}...")
  82.             cmd = [
  83.                 "huggingface-cli", "download",
  84.                 "--resume-download",
  85.                 "--local-dir", str(model_dir),
  86.                 "--clean-cache",  # Clean the cache to ensure fresh download
  87.                 model_id
  88.             ]
  89.            
  90.             result = subprocess.run(cmd, capture_output=True, text=True)
  91.            
  92.             if result.returncode != 0:
  93.                 print(f"Error synchronizing {model_id}: {result.stderr}")
  94.                 return
  95.                
  96.             # Verify all files are present and have correct hashes
  97.             missing_files = set(remote_files.keys()) - self.get_local_files(model_dir)
  98.             if missing_files:
  99.                 print(f"Warning: Missing files in {model_id}: {missing_files}")
  100.             else:
  101.                 print(f"Successfully synchronized {model_id}")
  102.                
  103.         except Exception as e:
  104.             print(f"Failed to synchronize {model_id}: {str(e)}")
  105.  
  106.     def sync_all(self, max_workers: int = 3):
  107.         """Synchronize all repositories for the user."""
  108.         try:
  109.             print(f"Fetching models for user {self.username}...")
  110.             models = self.get_user_models()
  111.            
  112.             if not models:
  113.                 print(f"No models found for user {self.username}")
  114.                 return
  115.            
  116.             print(f"Found {len(models)} models")
  117.            
  118.             # Synchronize models in parallel
  119.             with ThreadPoolExecutor(max_workers=max_workers) as executor:
  120.                 executor.map(self.sync_repository, models)
  121.                
  122.             print(f"\nSynchronization complete! Models are in {self.output_dir}")
  123.            
  124.         except requests.exceptions.RequestException as e:
  125.             print(f"Error fetching models: {str(e)}")
  126.         except Exception as e:
  127.             print(f"An unexpected error occurred: {str(e)}")
  128.  
  129. def main():
  130.     import argparse
  131.    
  132.     parser = argparse.ArgumentParser(
  133.         description='Synchronize Hugging Face models for a specific user'
  134.     )
  135.     parser.add_argument('username', help='Hugging Face username')
  136.     parser.add_argument('--workers', '-w', type=int, default=3,
  137.                        help='Number of parallel downloads (default: 3)')
  138.     parser.add_argument('--output-dir', '-o', type=str,
  139.                        help='Custom output directory')
  140.    
  141.     args = parser.parse_args()
  142.    
  143.     # Check if huggingface-cli is installed
  144.     try:
  145.         subprocess.run(["huggingface-cli", "--version"], capture_output=True)
  146.     except FileNotFoundError:
  147.         print("Error: huggingface-cli is not installed.")
  148.         print("Please install it using: pip install --upgrade huggingface_hub")
  149.         sys.exit(1)
  150.    
  151.     downloader = HFDownloader(args.username, args.output_dir)
  152.     downloader.sync_all(max_workers=args.workers)
  153.  
  154. if __name__ == "__main__":
  155.     main()
Advertisement
Add Comment
Please, Sign In to add comment