Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#!/usr/bin/env python3
import argparse
import hashlib
import json
import os
import shutil
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Set

import requests
class HFDownloader:
    """Mirror every Hugging Face model repository owned by one user.

    Each repository is downloaded with ``huggingface-cli download`` into
    ``<output_dir>/<repo name>``. Before downloading, local files that no
    longer exist in the remote repository are removed so the local copy
    matches the remote file list.

    Attributes:
        username: Hub account whose models are mirrored.
        output_dir: Directory that receives one sub-directory per model.
        api_url: Hub endpoint used to list models.
        force: When true, pass ``--force-download`` so files are re-fetched
            even if they are already present.
    """

    # Seconds to wait for any single Hub API response before giving up.
    REQUEST_TIMEOUT = 30

    # Folder that ``huggingface-cli download --local-dir`` maintains inside
    # the target directory; it is not part of the repository content.
    _HUB_METADATA_PREFIX = ".cache/huggingface/"

    def __init__(self, username: str, output_dir: Optional[str] = None, force: bool = False):
        """Store the configuration; no network or disk access happens here.

        Args:
            username: Hub account whose models should be mirrored.
            output_dir: Destination directory. Defaults to
                ``hf_downloads/<username>``.
            force: Re-download files even when they already exist locally.
        """
        self.username = username
        # Always a Path, whether or not the caller supplied a string.
        self.output_dir = Path(output_dir) if output_dir else Path(f"hf_downloads/{username}")
        self.api_url = "https://huggingface.co/api/models"
        self.force = force

    def _get_paginated(self, url: str, params: Optional[Dict[str, str]] = None) -> List[dict]:
        """Return the JSON items from every page of a Hub listing endpoint.

        Raises:
            requests.exceptions.RequestException: On network or HTTP errors.
        """
        items: List[dict] = []
        next_url: Optional[str] = url
        while next_url:
            response = requests.get(next_url, params=params, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            items.extend(response.json())
            # Further pages are advertised through a ``Link: <...>; rel="next"``
            # header whose URL already carries the full query string.
            next_url = response.links.get("next", {}).get("url")
            params = None
        return items

    @staticmethod
    def _index_tree_entries(entries) -> Dict[str, Optional[str]]:
        """Map each file path in a tree listing to its content hash.

        Directory entries are skipped. The LFS digest is preferred when
        present; otherwise the entry's git ``oid`` is used.
        """
        files: Dict[str, Optional[str]] = {}
        for item in entries:
            if "path" not in item or item.get("type") == "directory":
                continue
            lfs = item.get("lfs") or {}
            # The digest may be exposed as either "sha256" or "oid" depending
            # on the endpoint, so accept both.
            files[item["path"]] = lfs.get("sha256") or lfs.get("oid") or item.get("oid")
        return files

    @classmethod
    def _is_hub_metadata(cls, rel_path: str) -> bool:
        """Return True for download bookkeeping files kept by the Hub CLI."""
        return rel_path.startswith(cls._HUB_METADATA_PREFIX)

    def get_user_models(self) -> List[str]:
        """Fetch all model IDs for the configured username.

        Raises:
            requests.exceptions.RequestException: On network or HTTP errors.
        """
        models = self._get_paginated(self.api_url, {"author": self.username})
        return [model["id"] for model in models]

    def get_remote_files(self, model_id: str) -> Dict[str, Optional[str]]:
        """Get every file path in the remote ``main`` tree with its hash.

        Returns:
            ``{path: hash}`` using forward-slash paths relative to the
            repository root, or an empty dict if the listing could not be
            fetched (the error is printed, not raised).
        """
        url = f"https://huggingface.co/api/models/{model_id}/tree/main"
        try:
            # Without ``recursive`` only top-level entries are returned.
            entries = self._get_paginated(url, {"recursive": "true"})
            return self._index_tree_entries(entries)
        except Exception as e:
            print(f"Error fetching remote files for {model_id}: {str(e)}")
            return {}

    def get_local_files(self, model_dir: Path) -> Set[str]:
        """List all files under *model_dir* as forward-slash relative paths.

        Returns an empty set when the directory does not exist.
        """
        model_dir = Path(model_dir)
        if not model_dir.exists():
            return set()
        files = set()
        for root, _, filenames in os.walk(model_dir):
            for filename in filenames:
                rel_path = os.path.relpath(os.path.join(root, filename), model_dir)
                # Normalise separators so paths compare equal to remote ones.
                files.add(Path(rel_path).as_posix())
        return files

    def sync_repository(self, model_id: str) -> None:
        """Synchronize one local repository with its remote counterpart.

        Removes local files absent from the remote listing, downloads the
        repository, then reports any remote file still missing locally.
        All errors are printed rather than raised so that one failing
        repository does not abort the others.
        """
        try:
            model_dir = Path(self.output_dir) / model_id.split("/")[-1]

            print(f"Checking remote structure for {model_id}...")
            remote_files = self.get_remote_files(model_id)
            if not remote_files:
                print(f"Error: Could not fetch remote file list for {model_id}")
                return
            remote_paths = set(remote_files)

            # Stale files are those present locally but gone remotely; the
            # CLI's own metadata folder must survive or resuming breaks.
            extra_files = {
                path
                for path in self.get_local_files(model_dir) - remote_paths
                if not self._is_hub_metadata(path)
            }
            if extra_files:
                print(f"Removing {len(extra_files)} extra files from {model_id}")
                for file in extra_files:
                    (model_dir / file).unlink()

            print(f"Synchronizing {model_id}...")
            cmd = [
                "huggingface-cli", "download",
                "--resume-download",
                "--local-dir", str(model_dir),
            ]
            if self.force:
                cmd.append("--force-download")
            cmd.append(model_id)
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"Error synchronizing {model_id}: {result.stderr}")
                return

            # Presence check only: file contents are not re-hashed here.
            missing_files = remote_paths - self.get_local_files(model_dir)
            if missing_files:
                print(f"Warning: Missing files in {model_id}: {missing_files}")
            else:
                print(f"Successfully synchronized {model_id}")
        except Exception as e:
            print(f"Failed to synchronize {model_id}: {str(e)}")

    def sync_all(self, max_workers: int = 3):
        """Synchronize all repositories for the user.

        Args:
            max_workers: Number of repositories processed concurrently.
        """
        try:
            print(f"Fetching models for user {self.username}...")
            models = self.get_user_models()
            if not models:
                print(f"No models found for user {self.username}")
                return
            print(f"Found {len(models)} models")

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Drain the iterator so every task has finished (and any
                # unexpected worker error surfaces) before reporting success.
                list(executor.map(self.sync_repository, models))
            print(f"\nSynchronization complete! Models are in {self.output_dir}")
        except requests.exceptions.RequestException as e:
            print(f"Error fetching models: {str(e)}")
        except Exception as e:
            print(f"An unexpected error occurred: {str(e)}")
def main():
    """Parse command-line arguments and synchronize the user's models.

    Exits with status 2 on invalid arguments and status 1 when the
    ``huggingface-cli`` executable cannot be found.
    """
    parser = argparse.ArgumentParser(
        description='Synchronize Hugging Face models for a specific user'
    )
    parser.add_argument('username', help='Hugging Face username')
    parser.add_argument('--workers', '-w', type=int, default=3,
                        help='Number of parallel downloads (default: 3)')
    parser.add_argument('--output-dir', '-o', type=str,
                        help='Custom output directory')
    args = parser.parse_args()

    # Reject this up front: a thread pool cannot be created with fewer than
    # one worker, and failing later would hide the cause from the user.
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    # Check if huggingface-cli is installed
    try:
        subprocess.run(["huggingface-cli", "--version"], capture_output=True)
    except FileNotFoundError:
        print("Error: huggingface-cli is not installed.")
        print("Please install it using: pip install --upgrade huggingface_hub")
        sys.exit(1)

    downloader = HFDownloader(args.username, args.output_dir)
    downloader.sync_all(max_workers=args.workers)


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment