#!/usr/bin/env python3
"""
Final script to analyze a Hugging Face LLM for feasibility and performance,
using:
- Hugging Face Hub API for model param count + config data
- Fallback HTML scraping if the param count is not in the metadata
- Fallback approximations for layer count and hidden dim if the HF config is missing

It performs:
1. Model size detection (via HF API, scraping, or --params).
2. Hardware detection (VRAM, RAM, bandwidth) or user overrides.
3. Memory usage analysis for various quantizations.
4. If the model is "All in VRAM" or "KV cache offload," it calculates how many
   tokens of KV cache fit into leftover VRAM, using the real config if available
   or a fallback approximation.

Author's Note:
- This remains an approximation. The real architecture, layer shapes, and KV
  storage format can differ from the assumptions here.
"""
import argparse
import math
import os
import platform
import subprocess
import logging
import json
import re

import requests
import psutil
from bs4 import BeautifulSoup

# Attempt to import huggingface_hub (optional)
try:
    from huggingface_hub import HfApi
    HF_API_AVAILABLE = True
except ImportError:
    HF_API_AVAILABLE = False

###############################################################################
# GLOBAL CONSTANTS
###############################################################################

# Mapping of quantization level to "bits per weight" for the model weights
# (not necessarily for KV caches).
QUANTIZATION_BPWS = {
    "fp8": 8.0,
    "q6_k_s": 6.6,
    "q5_k_s": 5.5,
    "q4_k_m": 4.8,
    "IQ4_XS": 4.3,
    "q3_k_m": 3.9,
    "IQ3_XS": 3.3,
    "IQ2_XS": 2.4,
}
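# Rough worked example of what these bits-per-weight figures imply for the
# weights alone (overhead and KV cache come later, in analyze_quantization):
#   a 13B-parameter model at q4_k_m (4.8 bpw) needs about
#   13e9 * 4.8 / 8 / 1e9 ≈ 7.8 GB.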
###############################################################################
# LOGGING
###############################################################################
logger = logging.getLogger(__name__)

###############################################################################
# ARGUMENT PARSING
###############################################################################
def parse_args():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Analyze a Hugging Face model with quantization.")
    parser.add_argument(
        "model_id",
        nargs="?",
        default=None,
        help="Hugging Face model ID (e.g., 'microsoft/phi-4'). If not provided, the script prompts interactively."
    )
    parser.add_argument("-b", "--bandwidth", type=float, help="Override GPU bandwidth in GB/s.")
    parser.add_argument("-n", "--num-gpus", type=int, default=1, help="Number of identical GPUs.")
    parser.add_argument("-v", "--vram", type=float, help="Override detected VRAM in GB.")
    parser.add_argument(
        "--params",
        default=None,
        help="Manually specify model size (e.g. '13B', '350M') to skip scraping."
    )
    parser.add_argument(
        "--output",
        choices=["text", "json"],
        default="text",
        help="Output format: 'text' (default) or 'json'."
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress non-critical logs."
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug logging (overrides quiet)."
    )
    return parser.parse_args()

###############################################################################
# MODEL PARAM RETRIEVAL
###############################################################################
def get_model_params_hfapi(model_id):
    """
    Use huggingface_hub to retrieve param info from model card metadata or tags.
    Returns a string like "13B" or None if not found.
    """
    if not HF_API_AVAILABLE:
        logger.debug("huggingface_hub not installed. Skipping HF API retrieval.")
        return None
    logger.debug(f"Attempting HF API retrieval for model: {model_id}")
    try:
        api = HfApi()
        model_info = api.model_info(repo_id=model_id)
        logger.debug(f"HF API returned model_info: {model_info}")
        # 1) Check cardData for 'params' or 'model_size'
        if hasattr(model_info, "cardData") and model_info.cardData:
            logger.debug(f"cardData found: {model_info.cardData}")
            for key in ["params", "model_size"]:
                if key in model_info.cardData:
                    val = str(model_info.cardData[key])
                    logger.debug(f"Found {key} in cardData: {val}")
                    return val
        # 2) Check tags with a numeric pattern (e.g. "13B", "7.5B", "350M")
        pattern = re.compile(r"^\d+(\.\d+)?[bm]$", re.IGNORECASE)
        if hasattr(model_info, "tags") and model_info.tags:
            logger.debug(f"tags found: {model_info.tags}")
            for tag in model_info.tags:
                cleaned = tag.lower().replace("model-", "").strip()
                if pattern.match(cleaned):
                    logger.debug(f"Found a numeric param tag: {tag}")
                    return cleaned  # e.g. "7b"
        logger.debug("HF API used, but could not find param info in metadata.")
    except Exception as e:
        logger.warning(f"HF API error: {e}")
    return None

def get_model_params_scrape(model_id):
    """
    Fallback method: scrapes the huggingface.co/<model_id> page
    and looks for a 'Model size: 13B params' chip. Returns e.g. "13B params".
    """
    url = f"https://huggingface.co/{model_id}"
    logger.debug(f"Attempting web scrape for URL: {url}")
    try:
        response = requests.get(url, timeout=10)
        logger.debug(f"Scrape status code: {response.status_code}")
        if response.status_code != 200:
            logger.warning(f"Could not access URL {url}. Status: {response.status_code}")
            return None
        soup = BeautifulSoup(response.text, 'html.parser')
        # NOTE: this class string is tied to the current huggingface.co markup
        # and will silently stop matching if the site layout changes.
        logger.debug("Looking for divs with class 'inline-flex h-6 shrink-0 items-center overflow-hidden rounded-lg border'")
        param_divs = soup.find_all(
            'div',
            class_='inline-flex h-6 shrink-0 items-center overflow-hidden rounded-lg border'
        )
        logger.debug(f"Found {len(param_divs)} div(s) matching that class.")
        for i, div in enumerate(param_divs):
            text_content = div.get_text(strip=True)
            logger.debug(f"Div #{i} text: '{text_content}'")
            if 'Model size' in text_content:
                sub_divs = div.find_all('div')
                logger.debug(f"Found 'Model size' in Div #{i}, sub_divs count={len(sub_divs)}")
                if len(sub_divs) > 1:
                    size_text = sub_divs[1].text.strip()
                    logger.debug(f"Extracted size_text: '{size_text}'")
                    return size_text
        logger.debug("No div with 'Model size' found using that HTML structure.")
    except Exception as e:
        logger.warning(f"Scraping error: {e}")
    return None

def convert_params_to_b(size_text):
    """
    Convert strings like '13B params', '7.5B', or '350M' to a float param count.
    E.g. '13B' => 1.3e10, '350M' => 3.5e8.
    """
    if not size_text:
        logger.debug("convert_params_to_b received empty/None size_text.")
        return None
    s = size_text.lower().replace("params", "").strip()
    logger.debug(f"Converting size_text '{s}' to numeric param count.")
    if 'b' in s:
        try:
            val_str = s.replace('b', '').strip()
            val = float(val_str)
            logger.debug(f"Detected 'B' in string, returning {val * 1e9} as param count.")
            return val * 1e9
        except ValueError as e:
            logger.debug(f"ValueError in convert_params_to_b: {e}")
            return None
    elif 'm' in s:
        try:
            val_str = s.replace('m', '').strip()
            val = float(val_str)
            logger.debug(f"Detected 'M' in string, returning {val * 1e6} as param count.")
            return val * 1e6
        except ValueError as e:
            logger.debug(f"ValueError in convert_params_to_b: {e}")
            return None
    logger.debug(f"No 'b' or 'm' in '{s}', returning None.")
    return None
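# Illustrative expected values for convert_params_to_b:
#   convert_params_to_b("13B params") -> 1.3e10
#   convert_params_to_b("7.5B")       -> 7.5e9
#   convert_params_to_b("350M")       -> 3.5e8
#   convert_params_to_b("unknown")    -> None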
###############################################################################
# GET MODEL CONFIG FOR ACCURATE KV CACHE
###############################################################################
def get_model_config_details(model_id):
    """
    If available, uses the HF API to fetch model_info.config and extracts:
    - num_layers (could be "num_hidden_layers" or "n_layer")
    - hidden_size (could be "hidden_size" or "n_embd")
    - kv_bits for the KV cache (assume 16 if not specified).
    Returns (num_layers, hidden_size, kv_bits), or (None, None, 16) if not found.
    This function won't do scraping; it relies only on the HF config fields.
    If the HF API is not installed or no config is found, returns None for
    layers/size.
    """
    if not HF_API_AVAILABLE:
        logger.debug("HF API not available, cannot retrieve config details.")
        return None, None, 16
    try:
        api = HfApi()
        model_info = api.model_info(repo_id=model_id)
        cfg = getattr(model_info, "config", None)
        if not cfg:
            logger.debug("No config in model_info; returning fallback.")
            return None, None, 16
        # Attempt to parse the standard config fields
        logger.debug(f"HF config: {cfg}")
        num_layers = cfg.get("num_hidden_layers") or cfg.get("n_layer")
        hidden_size = cfg.get("hidden_size") or cfg.get("n_embd")
        # Some models might store additional metadata about KV precision.
        # For now, we default to 16 bits if not found. Adjust as needed.
        kv_bits = 16
        # If a field like "kv_precision" existed, you might parse it:
        # kv_precision = cfg.get("kv_precision", "fp16").lower()
        # if kv_precision == "fp32": kv_bits = 32
        # ...
        return num_layers, hidden_size, kv_bits
    except Exception as e:
        logger.debug(f"Error retrieving model config details: {e}")
        return None, None, 16
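# For reference, a typical decoder-only config exposes fields like
#   {"num_hidden_layers": 40, "hidden_size": 5120, ...}
# (the values shown are those published for Llama-2-13B). If neither naming
# variant is present, the KV estimate falls back to the parameter-count
# approximation further below.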
###############################################################################
# SYSTEM MEMORY + BANDWIDTH DETECTION
###############################################################################
def get_ram_specs():
    """Return total system RAM in GB."""
    total = psutil.virtual_memory().total / (1024**3)
    logger.debug(f"Detected system RAM: {total:.2f} GB")
    return total

def get_memory_bandwidth():
    """Estimate system RAM bandwidth in GB/s from detected DIMM speeds."""
    try:
        system = platform.system()
        logger.debug(f"Platform for memory bandwidth detection: {system}")
        if system == "Windows":
            cmd = ["powershell", "-Command", "Get-CimInstance Win32_PhysicalMemory | Select-Object -ExpandProperty Speed"]
            logger.debug(f"Running command: {' '.join(cmd)}")
            try:
                output = subprocess.check_output(cmd, timeout=5).decode().strip().split("\n")
                logger.debug(f"Raw output for RAM speeds: {output}")
                speeds = [int(s) for s in output if s.strip().isdigit()]
                if speeds:
                    max_speed = max(speeds)
                    # MT/s * 8 bytes per channel * 2 channels / 1000 => GB/s
                    bandwidth = max_speed * 8 * 2 / 1000
                    logger.debug(f"Windows RAM max speed = {max_speed}, estimated bandwidth = {bandwidth:.2f} GB/s")
                    return bandwidth
                logger.info("No known memory speeds found, defaulting to 48 GB/s.")
                return 48
            except Exception as e:
                logger.warning(f"Windows memory detection error: {e}")
                return 48
        elif system == "Linux":
            try:
                cmd = ["sudo", "dmidecode", "-t", "memory"]
                logger.debug(f"Running command: {' '.join(cmd)}")
                output = subprocess.check_output(cmd, timeout=5).decode().split("\n")
                logger.debug(f"Raw output of dmidecode: {output}")
                speeds = []
                for line in output:
                    if "Speed:" in line and "Unknown" not in line:
                        try:
                            spd = line.split(":")[-1].strip().split(" ")[0]
                            speeds.append(int(spd))
                        except ValueError:
                            pass
                if speeds:
                    max_speed = max(speeds)
                    bandwidth = max_speed * 8 * 2 / 1000
                    logger.debug(f"Linux RAM max speed = {max_speed}, estimated bandwidth = {bandwidth:.2f} GB/s")
                    return bandwidth
            except Exception as e:
                logger.warning(f"dmidecode call failed or timed out: {e}")
            logger.debug("Falling back to /proc/meminfo-based heuristic.")
            with open('/proc/meminfo', 'r') as f:
                mem_info = f.read().lower()
            if 'memtotal' in mem_info:
                total_kb = int(mem_info.split('memtotal:')[1].split('kb')[0].strip())
                total_gb = total_kb / (1024**2)
                # Crude heuristic: larger machines tend to have faster/multi-channel RAM.
                if total_gb >= 32:
                    return 64
                else:
                    return 48
        # Default for other OSes
        logger.info("Unsupported platform for memory detection, defaulting to 48 GB/s.")
        return 48
    except Exception as e:
        logger.error(f"Error retrieving RAM speed: {e}")
        return 48
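# Sanity check for the speed -> bandwidth conversion above (MT/s * 8 bytes per
# channel * 2 channels / 1000): dual-channel DDR4-3200 gives
#   3200 * 8 * 2 / 1000 = 51.2 GB/s,
# which matches its theoretical peak.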
###############################################################################
# GPU VRAM + BANDWIDTH DETECTION
###############################################################################
def get_vram_specs():
    """Detect VRAM (GB) and assign a heuristic GPU memory bandwidth (GB/s)."""
    vram = None
    bandwidth = None
    system = platform.system()
    logger.debug(f"Platform for VRAM detection: {system}")
    if system == "Windows":
        vram = detect_vram_windows()
    elif system == "Linux":
        vram = detect_vram_linux()
    else:
        logger.warning("Unsupported platform for VRAM detection; defaulting VRAM=0.")
        vram = 0
    logger.debug(f"Detected VRAM before bandwidth assignment: {vram}")
    if vram is not None and vram > 0:
        # Heuristic bandwidth assignment by VRAM tier
        if vram >= 49:
            bandwidth = 1500
        elif vram >= 25:
            bandwidth = 1790
        elif vram >= 17:
            bandwidth = 950
        elif vram >= 13:
            bandwidth = 550
        elif vram >= 9:
            bandwidth = 400
        elif vram >= 7:
            bandwidth = 300
        elif vram >= 5:
            bandwidth = 240
        else:
            bandwidth = 200
        logger.debug(f"Heuristic GPU bandwidth assigned: {bandwidth} GB/s")
    else:
        logger.warning("VRAM not detected, defaulting to 0 GB and 0 GB/s bandwidth.")
        vram = 0
        bandwidth = 0
    return vram, bandwidth

def detect_vram_windows():
    logger.debug("Attempting NVIDIA VRAM detection via nvidia-smi.")
    try:
        cmd = ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]
        logger.debug(f"Running command: {' '.join(cmd)}")
        output = subprocess.check_output(cmd, timeout=5).decode().strip()
        logger.debug(f"nvidia-smi output:\n{output}")
        lines = output.split("\n")
        if lines:
            max_mem_mb = max(float(x) for x in lines if x.strip())
            return max_mem_mb / 1024
    except Exception as e:
        logger.debug(f"Failed nvidia-smi detection: {e}")
    logger.debug("Attempting AMD/Intel VRAM detection via PowerShell WMI.")
    try:
        # NOTE: Win32_VideoController.AdapterRAM is a 32-bit value, so it caps
        # out around 4 GB on many systems; treat this as a lower bound.
        cmd = ["powershell", "-Command", "Get-WmiObject Win32_VideoController | Select-Object AdapterRAM"]
        logger.debug(f"Running command: {' '.join(cmd)}")
        output = subprocess.check_output(cmd, timeout=5).decode().strip()
        logger.debug(f"WMI output for AdapterRAM:\n{output}")
        for line in output.split('\n'):
            line = line.strip()
            if line.isdigit():
                vram_gb = int(line) / (1024**3)
                logger.debug(f"Detected VRAM from WMI: {vram_gb:.2f} GB")
                return vram_gb
    except Exception as e:
        logger.debug(f"Failed AMD VRAM detection via WMI: {e}")
    logger.debug("Checking for Intel Arc via WMI Description.")
    try:
        cmd = ["powershell", "-Command", "Get-WmiObject Win32_VideoController | Select-Object Description"]
        logger.debug(f"Running command: {' '.join(cmd)}")
        output = subprocess.check_output(cmd, timeout=5).decode().lower()
        logger.debug(f"WMI Description output:\n{output}")
        if 'intel' in output and 'arc' in output:
            # Known Intel Arc VRAM sizes, matched by model name
            if 'a770' in output:
                return 16
            elif 'b580' in output:
                return 12
            elif 'b570' in output:
                return 10
            elif 'a750' in output:
                return 8
            elif 'a380' in output:
                return 6
            elif 'a310' in output:
                return 4
    except Exception as e:
        logger.debug(f"Failed Intel Arc detection via WMI: {e}")
    return None

def detect_vram_linux():
    logger.debug("Attempting NVIDIA VRAM detection via nvidia-smi on Linux.")
    try:
        cmd = ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]
        logger.debug(f"Running command: {' '.join(cmd)}")
        output = subprocess.check_output(cmd, timeout=5).decode().strip()
        logger.debug(f"nvidia-smi output:\n{output}")
        lines = output.split("\n")
        if lines:
            max_mem_mb = max(float(x) for x in lines if x.strip())
            return max_mem_mb / 1024
    except Exception as e:
        logger.debug(f"Failed nvidia-smi detection: {e}")
    logger.debug("Attempting AMD VRAM detection via /sys path.")
    try:
        amd_paths = [
            "/sys/class/drm/card0/device/mem_info_vram_total",
            "/sys/class/gpu/card0/device/mem_info_vram_total"
        ]
        for path in amd_paths:
            if os.path.exists(path):
                with open(path, 'r') as f:
                    vram_bytes = int(f.read().strip())
                vram_gb = vram_bytes / (1024**3)
                logger.debug(f"Detected AMD VRAM from {path}: {vram_gb:.2f} GB")
                return vram_gb
    except Exception as e:
        logger.debug(f"Failed AMD VRAM detection via /sys: {e}")
    logger.debug("Attempting Intel Arc detection via lspci.")
    try:
        cmd = ["lspci", "-v"]
        logger.debug(f"Running command: {' '.join(cmd)}")
        output = subprocess.check_output(cmd, timeout=5).decode().lower()
        logger.debug(f"lspci output:\n{output}")
        if 'intel' in output and 'arc' in output:
            if 'a770' in output:
                return 16
            elif 'b580' in output:
                return 12
            elif 'b570' in output:
                return 10
            elif 'a750' in output:
                return 8
            elif 'a380' in output:
                return 6
            elif 'a310' in output:
                return 4
    except Exception as e:
        logger.debug(f"Failed Intel Arc detection via lspci: {e}")
    return None

###############################################################################
# KV CACHE CALCULATION WITH REAL CONFIG (FALLBACK if MISSING)
###############################################################################
def estimate_max_context_size(leftover_vram_gb, num_layers, hidden_size, kv_bits=16):
    """
    Uses the real config to compute memory usage for the KV cache:
        memory_per_token (bytes) = 2 * num_layers * hidden_size * (kv_bits / 8)
        leftover_vram_bytes = leftover_vram_gb * 1e9
        max_context = leftover_vram_bytes / memory_per_token
    (The factor of 2 covers the separate K and V tensors per layer.)
    If either num_layers or hidden_size is missing, returns 0 and lets the
    fallback logic handle it.
    """
    if leftover_vram_gb <= 0:
        return 0
    if not num_layers or not hidden_size:
        # Can't apply the real formula without config values.
        return 0
    leftover_bytes = leftover_vram_gb * 1e9
    mem_per_token = 2.0 * num_layers * hidden_size * (kv_bits / 8.0)
    max_context = leftover_bytes / mem_per_token
    logger.debug(
        f"Using real config for KV: L={num_layers}, hidden_size={hidden_size}, "
        f"kv_bits={kv_bits}, leftover={leftover_vram_gb:.2f} GB => {max_context:.1f} tokens"
    )
    return int(max_context)
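# Worked example of the formula above, assuming a 40-layer model with
# hidden_size=5120 and 16-bit KV:
#   mem_per_token = 2 * 40 * 5120 * 2 bytes = 819,200 bytes (~0.8 MB/token)
#   4 GB leftover => 4e9 / 819200 ≈ 4,882 tokens of context.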
def estimate_max_context_size_fallback(params_b, leftover_vram_gb, bpw):
    """
    Fallback: guess the layer count from the param count, guess the hidden dim
    via a square root, then assume kv_bits=16 (bpw applies only to the weights,
    so it is unused here). Same logic as the older approximate method.
    Returns 0 if leftover_vram_gb <= 0.
    """
    if leftover_vram_gb <= 0:
        return 0
    # Guess layer count from parameter count
    if params_b > 30e9:
        L = 60
    elif params_b > 10e9:
        L = 40
    elif params_b > 5e9:
        L = 32
    else:
        L = 28
    # Approximate d_model from param count:
    # param_count ~ 2 * L * d_model^2  =>  d_model ~ sqrt(params_b / (2 * L))
    d_model = math.sqrt(params_b / (2.0 * L))
    # Assume the KV cache uses 16 bits even if bpw < 16
    kv_bits = 16
    leftover_bytes = leftover_vram_gb * 1e9
    mem_per_token = 2.0 * L * d_model * (kv_bits / 8.0)
    max_context = leftover_bytes / mem_per_token
    logger.debug(f"Fallback KV calc => L={L}, d_model={d_model:.1f}, leftover={leftover_vram_gb:.2f} => {max_context:.1f} tokens")
    return int(max_context)
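# Worked example of the fallback for a 13B model:
#   params_b = 13e9 > 10e9    => L = 40
#   d_model ≈ sqrt(13e9 / 80) ≈ 12,748
# That d_model is far above the real ~5,120 of typical 13B models, so this
# fallback tends to overestimate per-token KV memory and thus underestimate
# the max context. Treat its output as a conservative floor.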
###############################################################################
# THROUGHPUT ESTIMATION
###############################################################################
def estimate_tks(ram_bandwidth, required_mem):
    """Estimate tokens/s when the model streams from system RAM (0.9 efficiency factor)."""
    logger.debug(f"Estimating tk/s with ram_bandwidth={ram_bandwidth:.2f} GB/s, required_mem={required_mem:.2f} GB.")
    return (ram_bandwidth / required_mem) * 0.9

def calculate_tks(base_tks, offload_ratio):
    """Scale RAM-bound tokens/s by an empirical speedup curve for partial GPU offload."""
    logger.debug(f"Calculating partial offload tk/s with base_tks={base_tks:.2f}, offload_ratio={offload_ratio:.2f}%.")
    return base_tks * (0.052 * math.exp(4.55 * (100 - offload_ratio) / 100) + 1.06)
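# Example of the offload penalty curve in calculate_tks: at offload_ratio = 50
# the multiplier is
#   0.052 * exp(4.55 * 0.5) + 1.06 ≈ 1.57,
# i.e. keeping half the model on the GPU is modeled as ~1.57x the all-in-RAM
# throughput. The constants appear to be empirical curve-fit values, not
# derived ones.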
###############################################################################
# ANALYSIS
###############################################################################
def analyze_quantization(params_b, vram_gb, bandwidth, ram_gb, quant, bpw, ram_bandwidth,
                         num_layers, hidden_size, kv_bits=16):
    """
    For a given quantization (bpw), compute:
    - total memory required for the weights (required_mem)
    - run type (all in VRAM, partial offload, etc.)
    - tokens/s (heuristic)
    - leftover VRAM => max_context tokens if run_type is "All in VRAM" or
      "KV cache offload"
    Returns (run_type, required_mem, offload_ratio, tks, max_ctx).
    """
    logger.debug(f"Analyzing quant={quant}, bpw={bpw}, params_b={params_b}, vram_gb={vram_gb}, bandwidth={bandwidth}")
    required_base = 1.0 + params_b * 0.05 / 1e9  # overhead: 1 GB fixed + 0.05 GB per billion params
    required_mem = required_base + (params_b * bpw / 8.0 / 1e9)
    logger.debug(f"Computed required_mem={required_mem:.2f} GB for quant={quant}.")
    run_type = "Won't run"
    tks = None
    offload_ratio = 0
    max_ctx = 0
    if required_mem <= vram_gb:
        run_type = "All in VRAM"
        tks = bandwidth / required_mem
        leftover_vram_gb = vram_gb - required_mem
        # First try the real config-based formula
        if num_layers and hidden_size:
            max_ctx = estimate_max_context_size(leftover_vram_gb, num_layers, hidden_size, kv_bits=kv_bits)
        else:
            # Fallback approximation
            max_ctx = estimate_max_context_size_fallback(params_b, leftover_vram_gb, bpw)
    elif required_mem <= vram_gb + 1:
        run_type = "KV cache offload"
        tks = (bandwidth / required_mem) * 0.9
        leftover_vram_gb = max((vram_gb + 1) - required_mem, 0)
        if num_layers and hidden_size:
            max_ctx = estimate_max_context_size(leftover_vram_gb, num_layers, hidden_size, kv_bits=kv_bits)
        else:
            max_ctx = estimate_max_context_size_fallback(params_b, leftover_vram_gb, bpw)
    elif vram_gb > 1 and required_mem <= (ram_gb + vram_gb):
        run_type = "Partial offload"
        offload_ratio = (required_mem - vram_gb) / required_mem * 100
        base_tks = estimate_tks(ram_bandwidth, required_mem)
        tks = calculate_tks(base_tks, offload_ratio)
        max_ctx = 0  # Not a straightforward leftover-VRAM scenario
    elif required_mem <= ram_gb:
        run_type = "All in System RAM"
        offload_ratio = 100
        base_tks = estimate_tks(ram_bandwidth, required_mem)
        tks = base_tks
        max_ctx = 0
    return run_type, required_mem, offload_ratio, tks, max_ctx
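# Worked example for analyze_quantization: a 13B model at q4_k_m (bpw=4.8)
# on a 24 GB GPU:
#   required_mem = (1.0 + 13e9 * 0.05 / 1e9) + 13e9 * 4.8 / 8 / 1e9
#                = 1.65 + 7.8 = 9.45 GB
#   9.45 <= 24  => "All in VRAM", leaving ~14.55 GB of VRAM for KV cache.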
def analyze_all_quantizations(params_b, vram_gb, bandwidth, ram_gb, ram_bandwidth,
                              num_layers, hidden_size, kv_bits):
    logger.debug("Analyzing all quantizations.")
    results = {}
    for quant, bpw in QUANTIZATION_BPWS.items():
        run_type, mem_usage, offload_ratio, tks, max_ctx = analyze_quantization(
            params_b, vram_gb, bandwidth, ram_gb, quant, bpw, ram_bandwidth,
            num_layers, hidden_size, kv_bits
        )
        results[quant] = {
            "run_type": run_type,
            "memory_required_gb": mem_usage,
            "offload_percentage": offload_ratio,
            "tokens_per_s": tks,
            "max_context_tokens": max_ctx
        }
    return results

###############################################################################
# MAIN
###############################################################################
def main():
    args = parse_args()
    # Logging: --debug takes precedence over --quiet
    if args.debug:
        logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
        logger.setLevel(logging.DEBUG)
    else:
        logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
        if args.quiet:
            logger.setLevel(logging.WARNING)
    logger.debug(f"Parsed arguments: {args}")
    # Model ID
    model_id = args.model_id
    if not model_id:
        model_id = input("Enter Hugging Face model ID (e.g., microsoft/phi-4): ")
    logger.info(f"Using user-provided model_id: {model_id}")
    # 1) Retrieve model params
    if args.params:
        param_text = args.params
        logger.debug(f"User provided params: {param_text}")
    else:
        param_text = get_model_params_hfapi(model_id)
        if not param_text:
            param_text = get_model_params_scrape(model_id)
        if not param_text:
            logger.error("Could not determine model parameters from HF API or page.")
            output_results(None, None, None, None, None, None, args.output)
            return
    # 2) Convert param string to numeric param count
    params_b = convert_params_to_b(param_text)
    if not params_b:
        logger.error(f"Failed to parse parameter string into a numeric value: '{param_text}'")
        output_results(None, None, None, None, None, None, args.output)
        return
    # 3) Retrieve config for real KV calculation if available
    num_layers, hidden_size, kv_bits = get_model_config_details(model_id)
    # 4) System detection (RAM, VRAM, bandwidth)
    total_ram = get_ram_specs()
    vram, bandwidth = get_vram_specs()
    logger.debug(f"Detected VRAM = {vram} GB, GPU bandwidth = {bandwidth} GB/s")
    if args.vram is not None:
        logger.info(f"Overriding detected VRAM with user-supplied value: {args.vram} GB")
        vram = args.vram
    if not bandwidth:
        logger.warning("GPU bandwidth not detected or is 0; defaulting to 200 GB/s.")
        bandwidth = 200
    if args.bandwidth is not None:
        logger.info(f"Overriding GPU bandwidth with user-supplied value: {args.bandwidth} GB/s")
        bandwidth = args.bandwidth
    # Handle multi-GPU: VRAM scales linearly, aggregate bandwidth is discounted
    if args.num_gpus > 1:
        vram_total = vram * args.num_gpus
        bandwidth_total = (bandwidth * args.num_gpus) * 0.42
        logger.debug(f"Multi-GPU scenario => vram_total={vram_total}, bandwidth_total={bandwidth_total}")
        vram = vram_total
        bandwidth = bandwidth_total
    ram_bandwidth = get_memory_bandwidth()
    # 5) Analyze
    results = analyze_all_quantizations(params_b, vram, bandwidth, total_ram, ram_bandwidth,
                                        num_layers, hidden_size, kv_bits)
    # Output
    output_results(
        model_id,
        param_text,
        params_b,
        vram,
        bandwidth,
        results,
        args.output,
        total_ram=total_ram,
        ram_bandwidth=ram_bandwidth
    )

def output_results(model_id, param_text, params_b, vram, bandwidth, results, mode,
                   total_ram=None, ram_bandwidth=None):
    if not results:
        if mode == "json":
            print(json.dumps({"error": "No results available"}, indent=2))
        else:
            print("No results available.")
        return
    summary = {
        "model_id": model_id,
        "param_text": param_text,
        "params_count": params_b,
        "system_ram_gb": total_ram,
        "vram_gb": vram,
        "gpu_bandwidth_gb_s": bandwidth,
        "ram_bandwidth_gb_s": ram_bandwidth,
        "quantization_analysis": results
    }
    if mode == "json":
        print(json.dumps(summary, indent=2))
    else:
        print(f"Model: {model_id}")
        print(f"Model Parameters (raw): {param_text}")
        if params_b:
            print(f"Converted Param Count: {params_b / 1e9:.2f}B parameters")
        print(f"System RAM: {total_ram:.2f} GB")
        print(f"Detected VRAM: {vram:.2f} GB")
        print(f"GPU Bandwidth (approx): {bandwidth:.2f} GB/s")
        print(f"System RAM Bandwidth (approx): {ram_bandwidth:.2f} GB/s")
        print("\nAnalysis per quantization level:")
        for quant, data in results.items():
            run_type = data["run_type"]
            mem_req = data["memory_required_gb"]
            offload_pct = data["offload_percentage"]
            tks = data["tokens_per_s"]
            max_ctx = data["max_context_tokens"]
            print(f"\nQuantization: {quant.upper()}")
            print(f"  - Run Type: {run_type}")
            print(f"  - Memory Required: {mem_req:.2f} GB")
            if offload_pct > 0:
                print(f"  - GPU Offload Percentage: {100 - offload_pct:.1f}% in GPU")
            if tks:
                print(f"  - Estimated tk/s: {tks:.2f}")
            if max_ctx > 0:
                print(f"  - Estimated Max Context Size (tokens) in leftover VRAM: {max_ctx}")
            else:
                print("  - Estimated Max Context Size (tokens): 0 (N/A or partial offload)")


if __name__ == "__main__":
    main()