Advertisement
Guest User

Untitled

a guest
Jan 27th, 2025
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 30.76 KB | Source Code | 0 0
  1. #!/usr/bin/env python3
  2. """
  3. Final script to analyze a Hugging Face LLM for feasibility and performance,
  4. using:
  5.  - Hugging Face Hub API for model param count + config data
  6.  - Fallback HTML scraping if param count is not in the metadata
  7.  - Fallback approximations for layer count, hidden dim if HF config is missing
  8.  
  9. It performs:
  10.  1. Model size detection (via HF API or scraping or --params).
  11.  2. Hardware detection (VRAM, RAM, bandwidth) or overrides.
  12.  3. Memory usage analysis for various quantizations.
  13.  4. If the model is "All in VRAM" or "KV cache offload," it calculates how many
  14.     tokens of KV cache fit into leftover VRAM, using real config if available
  15.     or a fallback approximation.
  16.  
  17. Author's Note:
  18. - This remains an approximation. The real architecture, layer shapes, and KV storage format
  19.  can differ from assumptions here.
  20. """
  21.  
  22. import argparse
  23. import math
  24. import os
  25. import platform
  26. import subprocess
  27. import logging
  28. import json
  29. import re
  30.  
  31. import requests
  32. import psutil
  33. from bs4 import BeautifulSoup
  34.  
  35. # Attempt to import huggingface_hub (optional)
  36. try:
  37.     from huggingface_hub import HfApi
  38.     HF_API_AVAILABLE = True
  39. except ImportError:
  40.     HF_API_AVAILABLE = False
  41.  
  42.  
  43. ###############################################################################
  44. # GLOBAL CONSTANTS
  45. ###############################################################################
  46.  
  47. # Mapping of quantization level to "bits per weight" for the model weights
  48. # (not necessarily for KV caches).
# Effective bits-per-weight for each supported weight-quantization level.
# Values are above the nominal bit width (e.g. 4.8 for q4_k_m), presumably
# to account for quantization-format overhead such as scales/block metadata
# — TODO confirm against the GGUF format docs.
# NOTE(review): key casing is inconsistent (lowercase "q4_k_m" vs uppercase
# "IQ4_XS"); the keys are emitted verbatim in the results dict, so they are
# deliberately left unchanged here.
QUANTIZATION_BPWS = {
    "fp8": 8.0,
    "q6_k_s": 6.6,
    "q5_k_s": 5.5,
    "q4_k_m": 4.8,
    "IQ4_XS": 4.3,
    "q3_k_m": 3.9,
    "IQ3_XS": 3.3,
    "IQ2_XS": 2.4
}
  59.  
  60.  
  61. ###############################################################################
  62. # LOGGING
  63. ###############################################################################
  64. logger = logging.getLogger(__name__)
  65.  
  66.  
  67. ###############################################################################
  68. # ARGUMENT PARSING
  69. ###############################################################################
  70. def parse_args():
  71.     """
  72.    Parse command-line arguments.
  73.    """
  74.     parser = argparse.ArgumentParser(description="Analyze Hugging Face model with quantization.")
  75.  
  76.     parser.add_argument(
  77.         "model_id",
  78.         nargs="?",
  79.         default=None,
  80.         help="Hugging Face model ID (e.g., 'microsoft/phi-4'). If not provided, script attempts an interactive prompt."
  81.     )
  82.     parser.add_argument("-b", "--bandwidth", type=float, help="Override GPU bandwidth in GB/s.")
  83.     parser.add_argument("-n", "--num-gpus", type=int, default=1, help="Number of identical GPUs.")
  84.     parser.add_argument("-v", "--vram", type=float, help="Override detected VRAM in GB.")
  85.     parser.add_argument(
  86.         "--params",
  87.         default=None,
  88.         help="Manually specify model size (e.g. '13B', '350M') to skip scraping."
  89.     )
  90.     parser.add_argument(
  91.         "--output",
  92.         choices=["text", "json"],
  93.         default="text",
  94.         help="Output format: 'text' (default) or 'json'."
  95.     )
  96.     parser.add_argument(
  97.         "--quiet",
  98.         action="store_true",
  99.         help="Suppress non-critical logs."
  100.     )
  101.     parser.add_argument(
  102.         "--debug",
  103.         action="store_true",
  104.         help="Enable debug logging (overrides quiet)."
  105.     )
  106.     return parser.parse_args()
  107.  
  108.  
  109. ###############################################################################
  110. # MODEL PARAM RETRIEVAL
  111. ###############################################################################
  112.  
  113. def get_model_params_hfapi(model_id):
  114.     """
  115.    Use huggingface_hub to retrieve param info from model card metadata or tags.
  116.    Returns a string like "13B" or None if not found.
  117.    """
  118.     if not HF_API_AVAILABLE:
  119.         logger.debug("huggingface_hub not installed. Skipping HF API retrieval.")
  120.         return None
  121.  
  122.     logger.debug(f"Attempting HF API retrieval for model: {model_id}")
  123.     try:
  124.         api = HfApi()
  125.         model_info = api.model_info(repo_id=model_id)
  126.         logger.debug(f"HF API returned model_info: {model_info}")
  127.  
  128.         # 1) Check cardData for 'params' or 'model_size'
  129.         if hasattr(model_info, "cardData") and model_info.cardData:
  130.             logger.debug(f"cardData found: {model_info.cardData}")
  131.             for key in ["params", "model_size"]:
  132.                 if key in model_info.cardData:
  133.                     val = str(model_info.cardData[key])
  134.                     logger.debug(f"Found {key} in cardData: {val}")
  135.                     return val
  136.  
  137.         # 2) Check tags with a numeric pattern (e.g. "13B", "7.5B", "350M")
  138.         pattern = re.compile(r"^\d+(\.\d+)?[bm]$", re.IGNORECASE)
  139.         if hasattr(model_info, "tags") and model_info.tags:
  140.             logger.debug(f"tags found: {model_info.tags}")
  141.             for tag in model_info.tags:
  142.                 cleaned = tag.lower().replace("model-", "").strip()
  143.                 if pattern.match(cleaned):
  144.                     logger.debug(f"Found a numeric param tag: {tag}")
  145.                     return cleaned  # e.g. "7b"
  146.  
  147.         logger.debug("HF API used, but could not find param info in metadata.")
  148.     except Exception as e:
  149.         logger.warning(f"HF API error: {e}")
  150.  
  151.     return None
  152.  
  153.  
  154. def get_model_params_scrape(model_id):
  155.     """
  156.    Fallback method: Scrapes the huggingface.co/<model_id> page
  157.    and looks for a 'Model size: 13B params' chip. Returns e.g. "13B params".
  158.    """
  159.     url = f"https://huggingface.co/{model_id}"
  160.     logger.debug(f"Attempting web scrape for URL: {url}")
  161.     try:
  162.         response = requests.get(url, timeout=10)
  163.         logger.debug(f"Scrape status code: {response.status_code}")
  164.         if response.status_code != 200:
  165.             logger.warning(f"Could not access URL {url}. Status: {response.status_code}")
  166.             return None
  167.  
  168.         soup = BeautifulSoup(response.text, 'html.parser')
  169.         logger.debug("Looking for divs with class 'inline-flex h-6 shrink-0 items-center overflow-hidden rounded-lg border'")
  170.         param_divs = soup.find_all(
  171.             'div',
  172.             class_='inline-flex h-6 shrink-0 items-center overflow-hidden rounded-lg border'
  173.         )
  174.  
  175.         logger.debug(f"Found {len(param_divs)} div(s) matching that class.")
  176.         for i, div in enumerate(param_divs):
  177.             text_content = div.get_text(strip=True)
  178.             logger.debug(f"Div #{i} text: '{text_content}'")
  179.             if 'Model size' in text_content:
  180.                 sub_divs = div.find_all('div')
  181.                 logger.debug(f"Found 'Model size' in Div #{i}, sub_divs count={len(sub_divs)}")
  182.                 if len(sub_divs) > 1:
  183.                     size_text = sub_divs[1].text.strip()
  184.                     logger.debug(f"Extracted size_text: '{size_text}'")
  185.                     return size_text
  186.         logger.debug("No div with 'Model size' found using that HTML structure.")
  187.     except Exception as e:
  188.         logger.warning(f"Scraping error: {e}")
  189.  
  190.     return None
  191.  
  192.  
  193. def convert_params_to_b(size_text):
  194.     """
  195.    Convert strings like '13B params' or '7.5B' or '350M' => float param count.
  196.    E.g. '13B' => 1.3e10, '350M' => 3.5e8.
  197.    """
  198.     if not size_text:
  199.         logger.debug("convert_params_to_b received empty/None size_text.")
  200.         return None
  201.  
  202.     s = size_text.lower().replace("params", "").strip()
  203.     logger.debug(f"Converting size_text '{s}' to numeric param count.")
  204.  
  205.     if 'b' in s:
  206.         try:
  207.             val_str = s.replace('b', '').strip()
  208.             val = float(val_str)
  209.             logger.debug(f"Detected 'B' in string, returning {val * 1e9} as param count.")
  210.             return val * 1e9
  211.         except ValueError as e:
  212.             logger.debug(f"ValueError in convert_params_to_b: {e}")
  213.             return None
  214.     elif 'm' in s:
  215.         try:
  216.             val_str = s.replace('m', '').strip()
  217.             val = float(val_str)
  218.             logger.debug(f"Detected 'M' in string, returning {val * 1e6} as param count.")
  219.             return val * 1e6
  220.         except ValueError as e:
  221.             logger.debug(f"ValueError in convert_params_to_b: {e}")
  222.             return None
  223.  
  224.     logger.debug(f"No 'b' or 'm' in {s}, returning None.")
  225.     return None
  226.  
  227.  
  228. ###############################################################################
  229. # GET MODEL CONFIG FOR ACCURATE KV CACHE
  230. ###############################################################################
  231. def get_model_config_details(model_id):
  232.     """
  233.    If available, uses HF API to fetch model_info.config and extracts:
  234.      - num_layers  (could be "num_hidden_layers" or "n_layer")
  235.      - hidden_size (could be "hidden_size" or "n_embd")
  236.      - kv_bits     for the KV cache (assume 16 if not specified).
  237.  
  238.    Returns (num_layers, hidden_size, kv_bits) or (None, None, 16) if not found.
  239.  
  240.    This function won't do scraping. It only relies on the HF config fields.
  241.    If HF API is not installed or no config is found, returns None for layers/size.
  242.    """
  243.     if not HF_API_AVAILABLE:
  244.         logger.debug("HF API not available, cannot retrieve config details.")
  245.         return None, None, 16
  246.  
  247.     try:
  248.         api = HfApi()
  249.         model_info = api.model_info(repo_id=model_id)
  250.         cfg = getattr(model_info, "config", None)
  251.         if not cfg:
  252.             logger.debug("No config in model_info; returning fallback.")
  253.             return None, None, 16
  254.  
  255.         # Attempt to parse the standard config fields
  256.         logger.debug(f"HF config: {cfg}")
  257.         num_layers = cfg.get("num_hidden_layers") or cfg.get("n_layer")
  258.         hidden_size = cfg.get("hidden_size") or cfg.get("n_embd")
  259.  
  260.         # Some models might store additional metadata about KV precision.
  261.         # For now, we default to 16 bits if not found. Adjust as needed.
  262.         kv_bits = 16
  263.  
  264.         # If a field like "kv_precision" existed, you might parse it:
  265.         # kv_precision = cfg.get("kv_precision", "fp16").lower()
  266.         # if kv_precision == "fp32": kv_bits = 32
  267.         # ...
  268.  
  269.         return num_layers, hidden_size, kv_bits
  270.     except Exception as e:
  271.         logger.debug(f"Error retrieving model config details: {e}")
  272.         return None, None, 16
  273.  
  274.  
  275. ###############################################################################
  276. # SYSTEM MEMORY + BANDWIDTH DETECTION
  277. ###############################################################################
  278. def get_ram_specs():
  279.     total = psutil.virtual_memory().total / (1024**3)
  280.     logger.debug(f"Detected system RAM: {total:.2f} GB")
  281.     return total
  282.  
  283.  
  284. def get_memory_bandwidth():
  285.     try:
  286.         system = platform.system()
  287.         logger.debug(f"Platform for memory bandwidth detection: {system}")
  288.         if system == "Windows":
  289.             cmd = ["powershell", "-Command", "Get-CimInstance Win32_PhysicalMemory | Select-Object -ExpandProperty Speed"]
  290.             logger.debug(f"Running command: {' '.join(cmd)}")
  291.             try:
  292.                 output = subprocess.check_output(cmd, timeout=5).decode().strip().split("\n")
  293.                 logger.debug(f"Raw output for RAM speeds: {output}")
  294.                 speeds = [int(s) for s in output if s.isdigit()]
  295.                 if speeds:
  296.                     max_speed = max(speeds)
  297.                     bandwidth = max_speed * 8 * 2 / 1000
  298.                     logger.debug(f"Windows RAM max speed = {max_speed}, estimated bandwidth = {bandwidth:.2f} GB/s")
  299.                     return bandwidth
  300.                 logger.info("No known memory speeds found, defaulting to 48 GB/s.")
  301.                 return 48
  302.             except Exception as e:
  303.                 logger.warning(f"Windows memory detection error: {e}")
  304.                 return 48
  305.  
  306.         elif system == "Linux":
  307.             try:
  308.                 cmd = ["sudo", "dmidecode", "-t", "memory"]
  309.                 logger.debug(f"Running command: {' '.join(cmd)}")
  310.                 output = subprocess.check_output(cmd, timeout=5).decode().split("\n")
  311.                 logger.debug(f"Raw output of dmidecode: {output}")
  312.                 speeds = []
  313.                 for line in output:
  314.                     if "Speed:" in line and "Unknown" not in line:
  315.                         try:
  316.                             spd = line.split(":")[-1].strip().split(" ")[0]
  317.                             speeds.append(int(spd))
  318.                         except:
  319.                             pass
  320.                 if speeds:
  321.                     max_speed = max(speeds)
  322.                     bandwidth = max_speed * 8 * 2 / 1000
  323.                     logger.debug(f"Linux RAM max speed = {max_speed}, estimated bandwidth = {bandwidth:.2f} GB/s")
  324.                     return bandwidth
  325.             except Exception as e:
  326.                 logger.warning(f"dmidecode call failed or timed out: {e}")
  327.                 pass
  328.  
  329.             logger.debug("Falling back to /proc/meminfo-based heuristic.")
  330.             with open('/proc/meminfo', 'r') as f:
  331.                 mem_info = f.read().lower()
  332.                 if 'memtotal' in mem_info:
  333.                     total_kb = int(mem_info.split('memtotal:')[1].split('kb')[0].strip())
  334.                     total_gb = total_kb / (1024**2)
  335.                     if total_gb >= 32:
  336.                         return 64
  337.                     else:
  338.                         return 48
  339.         # Default if other OS
  340.         logger.info("Unsupported platform for memory detection, defaulting to 48 GB/s.")
  341.         return 48
  342.  
  343.     except Exception as e:
  344.         logger.error(f"Error retrieving RAM speed: {e}")
  345.         return 48
  346.  
  347.  
  348. ###############################################################################
  349. # GPU VRAM + BANDWIDTH DETECTION
  350. ###############################################################################
  351. def get_vram_specs():
  352.     vram = None
  353.     bandwidth = None
  354.  
  355.     system = platform.system()
  356.     logger.debug(f"Platform for VRAM detection: {system}")
  357.  
  358.     if system == "Windows":
  359.         vram = detect_vram_windows()
  360.     elif system == "Linux":
  361.         vram = detect_vram_linux()
  362.     else:
  363.         logger.warning("Unsupported platform for VRAM detection; defaulting VRAM=0.")
  364.         vram = 0
  365.  
  366.     logger.debug(f"Detected VRAM before bandwidth assignment: {vram}")
  367.  
  368.     if vram is not None and vram > 0:
  369.         # Heuristic bandwidth assignment
  370.         if vram >= 49:
  371.             bandwidth = 1500
  372.         elif vram >= 25:
  373.             bandwidth = 1790
  374.         elif vram >= 17:
  375.             bandwidth = 950
  376.         elif vram >= 13:
  377.             bandwidth = 550
  378.         elif vram >= 9:
  379.             bandwidth = 400
  380.         elif vram >= 7:
  381.             bandwidth = 300
  382.         elif vram >= 5:
  383.             bandwidth = 240
  384.         else:
  385.             bandwidth = 200
  386.         logger.debug(f"Heuristic GPU bandwidth assigned: {bandwidth} GB/s")
  387.     else:
  388.         logger.warning("VRAM not detected, defaulting to 0 GB and 0 GB/s bandwidth.")
  389.         vram = 0
  390.         bandwidth = 0
  391.  
  392.     return vram, bandwidth
  393.  
  394.  
  395. def detect_vram_windows():
  396.     logger.debug("Attempting NVIDIA VRAM detection via nvidia-smi.")
  397.     try:
  398.         cmd = ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]
  399.         logger.debug(f"Running command: {' '.join(cmd)}")
  400.         output = subprocess.check_output(cmd, timeout=5).decode().strip()
  401.         logger.debug(f"nvidia-smi output:\n{output}")
  402.         lines = output.split("\n")
  403.         if lines:
  404.             max_mem_mb = max(float(x) for x in lines if x.strip())
  405.             return max_mem_mb / 1024
  406.     except Exception as e:
  407.         logger.debug(f"Failed nvidia-smi detection: {e}")
  408.  
  409.     logger.debug("Attempting AMD/Intel VRAM detection via PowerShell WMI.")
  410.     try:
  411.         cmd = ["powershell", "-Command", "Get-WmiObject Win32_VideoController | Select-Object AdapterRAM"]
  412.         logger.debug(f"Running command: {' '.join(cmd)}")
  413.         output = subprocess.check_output(cmd, timeout=5).decode().strip()
  414.         logger.debug(f"WMI output for AdapterRAM:\n{output}")
  415.         for line in output.split('\n'):
  416.             line = line.strip()
  417.             if line.isdigit():
  418.                 vram_gb = int(line) / (1024**3)
  419.                 logger.debug(f"Detected VRAM from WMI: {vram_gb:.2f} GB")
  420.                 return vram_gb
  421.     except Exception as e:
  422.         logger.debug(f"Failed AMD VRAM detection via WMI: {e}")
  423.  
  424.     logger.debug("Checking for Intel Arc via WMI Description.")
  425.     try:
  426.         cmd = ["powershell", "-Command", "Get-WmiObject Win32_VideoController | Select-Object Description"]
  427.         logger.debug(f"Running command: {' '.join(cmd)}")
  428.         output = subprocess.check_output(cmd, timeout=5).decode().lower()
  429.         logger.debug(f"WMI Description output:\n{output}")
  430.         if 'intel' in output and 'arc' in output:
  431.             if 'a770' in output:
  432.                 return 16
  433.             elif 'b580' in output:
  434.                 return 12
  435.             elif 'b570' in output:
  436.                 return 10
  437.             elif 'a750' in output:
  438.                 return 8
  439.             elif 'a380' in output:
  440.                 return 6
  441.             elif 'a310' in output:
  442.                 return 4
  443.     except Exception as e:
  444.         logger.debug(f"Failed Intel Arc detection via WMI: {e}")
  445.  
  446.     return None
  447.  
  448.  
  449. def detect_vram_linux():
  450.     logger.debug("Attempting NVIDIA VRAM detection via nvidia-smi on Linux.")
  451.     try:
  452.         cmd = ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]
  453.         logger.debug(f"Running command: {' '.join(cmd)}")
  454.         output = subprocess.check_output(cmd, timeout=5).decode().strip()
  455.         logger.debug(f"nvidia-smi output:\n{output}")
  456.         lines = output.split("\n")
  457.         if lines:
  458.             max_mem_mb = max(float(x) for x in lines if x.strip())
  459.             return max_mem_mb / 1024
  460.     except Exception as e:
  461.         logger.debug(f"Failed nvidia-smi detection: {e}")
  462.  
  463.     logger.debug("Attempting AMD VRAM detection via /sys path.")
  464.     try:
  465.         amd_paths = [
  466.             "/sys/class/drm/card0/device/mem_info_vram_total",
  467.             "/sys/class/gpu/card0/device/mem_info_vram_total"
  468.         ]
  469.         for path in amd_paths:
  470.             if os.path.exists(path):
  471.                 with open(path, 'r') as f:
  472.                     vram_bytes = int(f.read().strip())
  473.                 vram_gb = vram_bytes / (1024**3)
  474.                 logger.debug(f"Detected AMD VRAM from {path}: {vram_gb:.2f} GB")
  475.                 return vram_gb
  476.     except Exception as e:
  477.         logger.debug(f"Failed AMD VRAM detection via /sys: {e}")
  478.  
  479.     logger.debug("Attempting Intel Arc detection via lspci.")
  480.     try:
  481.         cmd = ["lspci", "-v"]
  482.         logger.debug(f"Running command: {' '.join(cmd)}")
  483.         output = subprocess.check_output(cmd, timeout=5).decode().lower()
  484.         logger.debug(f"lspci output:\n{output}")
  485.         if 'intel' in output and 'arc' in output:
  486.             if 'a770' in output:
  487.                 return 16
  488.             elif 'b580' in output:
  489.                 return 12
  490.             elif 'b570' in output:
  491.                 return 10
  492.             elif 'a750' in output:
  493.                 return 8
  494.             elif 'a380' in output:
  495.                 return 6
  496.             elif 'a310' in output:
  497.                 return 4
  498.     except Exception as e:
  499.         logger.debug(f"Failed Intel Arc detection via lspci: {e}")
  500.  
  501.     return None
  502.  
  503.  
  504. ###############################################################################
  505. # KV CACHE CALCULATION WITH REAL CONFIG (FALLBACK if MISSING)
  506. ###############################################################################
  507. def estimate_max_context_size(
  508.     leftover_vram_gb, num_layers, hidden_size, kv_bits=16
  509. ):
  510.     """
  511.    Uses real config to compute memory usage for KV cache:
  512.      memory_per_token (bytes) = 2 * num_layers * hidden_size * (kv_bits/8)
  513.      leftover_vram_bytes = leftover_vram_gb * 1e9
  514.      => max_context = leftover_vram_bytes / memory_per_token
  515.  
  516.    If either num_layers or hidden_size is missing, we return 0 (let a fallback logic handle it).
  517.    """
  518.     if leftover_vram_gb <= 0:
  519.         return 0
  520.     if not num_layers or not hidden_size:
  521.         # can't do a real formula
  522.         return 0
  523.  
  524.     leftover_bytes = leftover_vram_gb * 1e9
  525.     mem_per_token = 2.0 * num_layers * hidden_size * (kv_bits / 8.0)
  526.     max_context = leftover_bytes / mem_per_token
  527.     logger.debug(
  528.         f"Using real config for KV: L={num_layers}, hidden_size={hidden_size}, kv_bits={kv_bits}, leftover={leftover_vram_gb:.2f} GB => {max_context:.1f} tokens"
  529.     )
  530.     return int(max_context)
  531.  
  532.  
  533. def estimate_max_context_size_fallback(params_b, leftover_vram_gb, bpw):
  534.     """
  535.    Fallback: guess layer count from param size, guess hidden dim from sqrt,
  536.    then assume kv_bits=16. Same logic as the older approximate method.
  537.  
  538.    If leftover_vram_gb <= 0 => 0
  539.    """
  540.     if leftover_vram_gb <= 0:
  541.         return 0
  542.  
  543.     # guess layer count
  544.     if params_b > 30e9:
  545.         L = 60
  546.     elif params_b > 10e9:
  547.         L = 40
  548.     elif params_b > 5e9:
  549.         L = 32
  550.     else:
  551.         L = 28
  552.  
  553.     # approximate d_model from param_count
  554.     # param_count ~ 2 * L * d_model^2 => d_model ~ sqrt(params_b/(2*L))
  555.     d_model = math.sqrt(params_b / (2.0 * L))
  556.  
  557.     # we assume KV uses 16 bits even if bpw < 16
  558.     kv_bits = 16
  559.     leftover_bytes = leftover_vram_gb * 1e9
  560.     mem_per_token = 2.0 * L * d_model * (kv_bits / 8.0)
  561.     max_context = leftover_bytes / mem_per_token
  562.  
  563.     logger.debug(f"Fallback KV calc => L={L}, d_model={d_model:.1f}, leftover={leftover_vram_gb:.2f} => {max_context:.1f} tokens")
  564.     return int(max_context)
  565.  
  566.  
  567. ###############################################################################
  568. # THROUGHPUT ESTIMATION
  569. ###############################################################################
  570. def estimate_tks(ram_bandwidth, required_mem):
  571.     logger.debug(f"Estimating tk/s with ram_bandwidth={ram_bandwidth:.2f} GB/s, required_mem={required_mem:.2f} GB.")
  572.     return (ram_bandwidth / required_mem) * 0.9
  573.  
  574.  
  575. def calculate_tks(base_tks, offload_ratio):
  576.     logger.debug(f"Calculating partial offload tk/s with base_tks={base_tks:.2f}, offload_ratio={offload_ratio:.2f}%.")
  577.     return base_tks * (0.052 * math.exp(4.55 * (100 - offload_ratio) / 100) + 1.06)
  578.  
  579.  
  580. ###############################################################################
  581. # ANALYSIS
  582. ###############################################################################
def analyze_quantization(params_b, vram_gb, bandwidth, ram_gb, quant, bpw, ram_bandwidth,
                         num_layers, hidden_size, kv_bits=16):
    """
    For a given quantization (bpw), compute:
      - total memory required for weights (required_mem, in GB)
      - run type ("All in VRAM", "KV cache offload", "Partial offload",
        "All in System RAM", or "Won't run")
      - tokens/s (heuristic; None when the model won't run)
      - leftover VRAM => max_context tokens when run_type is "All in VRAM"
        or "KV cache offload" (0 for the other run types)

    Returns (run_type, required_mem, offload_ratio, tks, max_ctx).
    """
    logger.debug(f"Analyzing quant={quant}, bpw={bpw}, params_b={params_b}, vram_gb={vram_gb}, bandwidth={bandwidth}")
    # Fixed 1 GB base plus ~0.05 bytes/param of runtime overhead (heuristic).
    required_base = 1.0 + params_b * 0.05 / 1e9  # overhead in GB
    # Weights themselves: params * bits-per-weight, expressed in decimal GB.
    required_mem = required_base + (params_b * bpw / 8.0 / 1e9)
    logger.debug(f"Computed required_mem={required_mem:.2f} GB for quant={quant}.")

    # Defaults correspond to the "nothing fits" case.
    run_type = "Won't run"
    tks = None
    offload_ratio = 0
    max_ctx = 0

    # Branch order matters: each case assumes the previous ones failed.
    if required_mem <= vram_gb:
        run_type = "All in VRAM"
        tks = bandwidth / required_mem
        leftover_vram_gb = vram_gb - required_mem
        # Prefer the real-config KV formula; fall back to the approximation.
        if num_layers and hidden_size:
            max_ctx = estimate_max_context_size(leftover_vram_gb, num_layers, hidden_size, kv_bits=kv_bits)
        else:
            # fallback
            max_ctx = estimate_max_context_size_fallback(params_b, leftover_vram_gb, bpw)

    elif required_mem <= vram_gb + 1:
        # Within 1 GB of fitting: presumably up to ~1 GB of KV cache spills
        # to system RAM at a ~10% throughput cost. NOTE(review): the 1 GB
        # grace margin is a hard-coded heuristic — confirm before tuning.
        run_type = "KV cache offload"
        tks = (bandwidth / required_mem) * 0.9
        leftover_vram_gb = (vram_gb + 1) - required_mem
        if leftover_vram_gb < 0:
            leftover_vram_gb = 0
        if num_layers and hidden_size:
            max_ctx = estimate_max_context_size(leftover_vram_gb, num_layers, hidden_size, kv_bits=kv_bits)
        else:
            max_ctx = estimate_max_context_size_fallback(params_b, leftover_vram_gb, bpw)

    elif vram_gb > 1 and required_mem <= (ram_gb + vram_gb):
        run_type = "Partial offload"
        # Percentage of the model that must live in system RAM.
        offload_ratio = (required_mem - vram_gb) / required_mem * 100
        base_tks = estimate_tks(ram_bandwidth, required_mem)
        tks = calculate_tks(base_tks, offload_ratio)
        max_ctx = 0  # Not a straightforward leftover VRAM scenario

    elif required_mem <= ram_gb:
        run_type = "All in System RAM"
        offload_ratio = 100
        base_tks = estimate_tks(ram_bandwidth, required_mem)
        tks = base_tks
        max_ctx = 0

    return run_type, required_mem, offload_ratio, tks, max_ctx
  641.  
  642.  
  643. def analyze_all_quantizations(params_b, vram_gb, bandwidth, ram_gb, ram_bandwidth,
  644.                               num_layers, hidden_size, kv_bits):
  645.     logger.debug("Analyzing all quantizations.")
  646.     results = {}
  647.     for quant, bpw in QUANTIZATION_BPWS.items():
  648.         run_type, mem_usage, offload_ratio, tks, max_ctx = analyze_quantization(
  649.             params_b, vram_gb, bandwidth, ram_gb, quant, bpw, ram_bandwidth,
  650.             num_layers, hidden_size, kv_bits
  651.         )
  652.         results[quant] = {
  653.             "run_type": run_type,
  654.             "memory_required_gb": mem_usage,
  655.             "offload_percentage": offload_ratio,
  656.             "tokens_per_s": tks,
  657.             "max_context_tokens": max_ctx
  658.         }
  659.     return results
  660.  
  661.  
  662. ###############################################################################
  663. # MAIN
  664. ###############################################################################
def main():
    """Entry point: resolve model size, detect hardware, analyze, and print.

    Pipeline: parse CLI -> configure logging -> resolve model_id (arg or
    interactive prompt) -> obtain a parameter-count string (--params, HF API,
    then page scrape) -> fetch config for the KV-cache formula -> detect or
    override RAM/VRAM/bandwidth -> analyze every quantization level ->
    emit results via output_results.
    """
    args = parse_args()

    # Logging: --debug wins over --quiet; --quiet only raises this module's
    # logger to WARNING while the root stays at INFO.
    if args.debug:
        logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
        logger.setLevel(logging.DEBUG)
    else:
        logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
        if args.quiet:
            logger.setLevel(logging.WARNING)

    logger.debug(f"Parsed arguments: {args}")

    # Model ID: fall back to an interactive prompt when not given on the CLI.
    model_id = args.model_id
    if not model_id:
        model_id = input("Enter Hugging Face model ID (e.g., microsoft/phi-4): ")
        logger.info(f"Using user-provided model_id: {model_id}")

    # 1) Retrieve model params: --params overrides, else HF API, else scrape.
    if args.params:
        param_text = args.params
        logger.debug(f"User provided params: {param_text}")
    else:
        param_text = get_model_params_hfapi(model_id)
        if not param_text:
            param_text = get_model_params_scrape(model_id)

    if not param_text:
        logger.error("Could not determine model parameters from HF API or page.")
        output_results(None, None, None, None, None, None, args.output)
        return

    # 2) Convert param string (e.g. "13B") to a numeric parameter count.
    params_b = convert_params_to_b(param_text)
    if not params_b:
        logger.error(f"Failed to parse parameter string into a numeric value: '{param_text}'")
        output_results(None, None, None, None, None, None, args.output)
        return

    # 3) Retrieve config (layers/hidden size) for the real KV-cache formula.
    num_layers, hidden_size, kv_bits = get_model_config_details(model_id)

    # 4) System detection (RAM, VRAM, bandwidth), then user overrides.
    total_ram = get_ram_specs()
    vram, bandwidth = get_vram_specs()

    logger.debug(f"Detected VRAM = {vram} GB, GPU bandwidth = {bandwidth} GB/s")

    if args.vram is not None:
        logger.info(f"Overriding detected VRAM with user-supplied value: {args.vram} GB")
        vram = args.vram

    # Guarantee a nonzero bandwidth before the --bandwidth override is applied.
    if not bandwidth or bandwidth == 0:
        logger.warning("GPU bandwidth not detected or is 0; defaulting to 200 GB/s.")
        bandwidth = 200

    if args.bandwidth is not None:
        logger.info(f"Overriding GPU bandwidth with user-supplied value: {args.bandwidth} GB/s")
        bandwidth = args.bandwidth

    # Multi-GPU: VRAM sums, but aggregate bandwidth is scaled by 0.42 —
    # presumably an inter-GPU scaling-efficiency factor; TODO confirm.
    if args.num_gpus > 1:
        vram_total = vram * args.num_gpus
        bandwidth_total = (bandwidth * args.num_gpus) * 0.42
        logger.debug(f"Multi-GPU scenario => vram_total={vram_total}, bandwidth_total={bandwidth_total}")
        vram = vram_total
        bandwidth = bandwidth_total

    ram_bandwidth = get_memory_bandwidth()

    # 5) Analyze every quantization level against the detected hardware.
    results = analyze_all_quantizations(params_b, vram, bandwidth, total_ram, ram_bandwidth,
                                        num_layers, hidden_size, kv_bits)

    # Output in the requested format (text or json).
    output_results(
        model_id,
        param_text,
        params_b,
        vram,
        bandwidth,
        results,
        args.output,
        total_ram=total_ram,
        ram_bandwidth=ram_bandwidth
    )
  753.  
  754.  
  755. def output_results(model_id, param_text, params_b, vram, bandwidth, results, mode,
  756.                    total_ram=None, ram_bandwidth=None):
  757.     if not results:
  758.         if mode == "json":
  759.             print(json.dumps({"error": "No results available"}, indent=2))
  760.         else:
  761.             print("No results available.")
  762.         return
  763.  
  764.     summary = {
  765.         "model_id": model_id,
  766.         "param_text": param_text,
  767.         "params_count": params_b,
  768.         "system_ram_gb": total_ram,
  769.         "vram_gb": vram,
  770.         "gpu_bandwidth_gb_s": bandwidth,
  771.         "ram_bandwidth_gb_s": ram_bandwidth,
  772.         "quantization_analysis": results
  773.     }
  774.  
  775.     if mode == "json":
  776.         print(json.dumps(summary, indent=2))
  777.     else:
  778.         print(f"Model: {model_id}")
  779.         print(f"Model Parameters (raw): {param_text}")
  780.         if params_b:
  781.             print(f"Converted Param Count: {params_b / 1e9:.2f}B parameters")
  782.         print(f"System RAM: {total_ram:.2f} GB")
  783.         print(f"Detected VRAM: {vram:.2f} GB")
  784.         print(f"GPU Bandwidth (approx): {bandwidth:.2f} GB/s")
  785.         print(f"System RAM Bandwidth (approx): {ram_bandwidth:.2f} GB/s")
  786.  
  787.         print("\nAnalysis per quantization level:")
  788.         for quant, data in results.items():
  789.             run_type = data["run_type"]
  790.             mem_req = data["memory_required_gb"]
  791.             offload_pct = data["offload_percentage"]
  792.             tks = data["tokens_per_s"]
  793.             max_ctx = data["max_context_tokens"]
  794.  
  795.             print(f"\nQuantization: {quant.upper()}")
  796.             print(f"  - Run Type: {run_type}")
  797.             print(f"  - Memory Required: {mem_req:.2f} GB")
  798.             if offload_pct > 0:
  799.                 print(f"  - GPU Offload Percentage: {100 - offload_pct:.1f}% in GPU")
  800.             if tks:
  801.                 print(f"  - Estimated tk/s: {tks:.2f}")
  802.             if max_ctx > 0:
  803.                 print(f"  - Estimated Max Context Size (tokens) in leftover VRAM: {max_ctx}")
  804.             else:
  805.                 print("  - Estimated Max Context Size (tokens): 0 (N/A or partial offload)")
  806.  
  807.  
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
  810.  
Tags: python
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement