  1. """
  2. Grand‑Tour‑LLM
  3. ==============
  4. A *stand‑alone* Python script that automates a grand tour of an n‑dimensional
  5. dataset and lets an OpenAI *vision* model pick the k most interesting 2‑D
  6. projections.
  7.  
  8. Usage (after `pip install -r requirements.txt` and setting $OPENAI_API_KEY):
  9.  
  10.    python grand_tour_llm.py data.csv --steps 5000 --out ./results \
  11.        --model gpt-4o-vision-preview \
  12.        --prompt "You are looking for small isolated clusters. Rate..."
  13.  
  14. Arguments
  15. ---------
  16. Positional
  17.    csv_path          Path to a **numeric** CSV file (rows=samples, columns=features).
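                     A header row is expected; non-numeric columns are dropped
                     automatically before the tour. An illustrative minimal file:

                         x1,x2,x3
                         0.1,2.3,-1.0
                         0.4,1.9,-0.8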

Optional
   -o, --out         Output directory (default ./llm_tour_results)
   -s, --steps       Frames to explore (default 100)
   -m, --model       ChatCompletion model name (default gpt-4o-mini)
   -k, --keep        Override k (top views kept). Default = round(3*sqrt(d)).
   -p, --prompt      Override the system prompt that rates plots.
   --min-div         Minimum Grassmann distance between kept planes (default 0.15)
   --seed            Random seed.

Outputs
-------
*out_dir*/
   view_01.png, view_02.png, ...  – PNGs of top-k scatterplots
   views.json                     – JSON with rating, summary, basis (flattened) per view

Requirements
------------
openai>=1.14.0, numpy, pandas, scipy, scikit-learn, matplotlib, pillow, tqdm

"""
from __future__ import annotations
import os, io, json, base64, argparse, heapq, math, time, random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import trange

from scipy.linalg import subspace_angles
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import FastICA

import matplotlib
matplotlib.use("Agg")  # headless
import matplotlib.pyplot as plt

import openai

# ---------------------------------------------------------------------------
# ------------------------------  HELPERS  -----------------------------------
# ---------------------------------------------------------------------------

def robust_scale(X: np.ndarray) -> np.ndarray:
    """Center by median and scale by IQR (RobustScaler)."""
    scaler = RobustScaler().fit(X)
    return scaler.transform(X)
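
# A single extreme feature would otherwise stretch one axis of every projection
# X @ B; median/IQR scaling keeps such outliers from dominating the tour.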


def initial_plane_ica(X: np.ndarray) -> np.ndarray:
    """Return p×2 orthonormal basis spanning the first two independent components."""
    ica = FastICA(n_components=2, whiten="unit-variance", random_state=0)
    ica.fit(X)
    W = ica.components_.T  # shape (p, 2)
    # Orthonormalise via QR so later Grassmann distances behave nicely
    Q, _ = np.linalg.qr(W)
    return Q[:, :2]


def random_plane(p: int) -> np.ndarray:
    """Return a p×2 random orthonormal basis."""
    A = np.random.normal(size=(p, 2))
    Q, _ = np.linalg.qr(A)
    return Q[:, :2]


def render_scatter(points: np.ndarray, labels=None, size=(512, 512)) -> bytes:
    """Render 2-D points → PNG bytes."""
    fig = plt.figure(figsize=(size[0] / 100, size[1] / 100), dpi=100)
    ax = fig.add_subplot(111)
    ax.scatter(points[:, 0], points[:, 1], s=6, c=labels, cmap="tab10", alpha=0.8)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("")
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return buf.read()


def grassmann_dist(B1: np.ndarray, B2: np.ndarray) -> float:
    """Principal-angle (Riemannian) distance between two 2-planes in R^p."""
    theta = subspace_angles(B1, B2)
    return float(np.linalg.norm(theta))
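
# For intuition: the two principal angles each lie in [0, pi/2], so this distance
# ranges from 0 (identical planes) to pi/sqrt(2) ≈ 2.22 (fully orthogonal
# 2-planes). Against that scale the default --min-div of 0.15 is a fairly mild
# diversity requirement.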

# ---------------------------------------------------------------------------
# ---------------------------  LLM INTERACTION  ------------------------------
# ---------------------------------------------------------------------------
DEFAULT_PROMPT = (
    "You are a data scientist visualizing scatter plots to gain insight. "
    "Rate the scatterplot's interestingness from 1 (dull) to 10 (very revealing) "
    "*numerically* under the key 'rating' and explain why in <100 words under the key 'summary'. "
    "Look for clustering of data points, both major clusters and smallish isolated clusters, "
    "and spot linear or non-linear dependencies, both global and applying only to a cluster. "
    "Return a *valid JSON object* with keys 'rating' (int) and 'summary' (str)."
)
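
# A well-behaved reply is a bare JSON object, e.g. (illustrative values):
#   {"rating": 7, "summary": "Two dense clusters joined by a sparse bridge of points."}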


def rate_with_llm(img_bytes: bytes, model: str, system_prompt: str) -> tuple[int, str]:
    """Call the chat model on one image. Returns (rating, summary)."""
    b64 = base64.b64encode(img_bytes).decode()
    client = openai.OpenAI()
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here is the scatterplot:"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{b64}", "detail": "low"},
                },
            ],
        },
    ]
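    # Optional hardening (a suggestion, not enabled here): models that support
    # OpenAI's JSON mode can additionally be passed
    #     response_format={"type": "json_object"}
    # in the create() call below, which makes an unparseable reply less likely;
    # the try/except still guards against one.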
    resp = client.chat.completions.create(model=model, messages=messages)
    try:
        js = json.loads(resp.choices[0].message.content)
        return int(js["rating"]), str(js["summary"])
    except Exception as e:  # noqa: BLE001
        # Fallback: treat any parse failure as '1 - dull'
        return 1, f"(parse-error) {e}: {resp.choices[0].message.content[:60]}..."

# ---------------------------------------------------------------------------
# ----------------------------  TOP-K HEAP  ----------------------------------
# ---------------------------------------------------------------------------
class TopKViews:
    def __init__(self, k: int, min_div: float):
        self.k = k
        self.min_div = min_div
        self._counter = 0  # monotonic tie-breaker so equal scores never compare records
        self.heap: list[tuple[int, int, dict]] = []  # (score, counter, record)

    def maybe_add(self, record: dict):
        score, basis = record["score"], record["basis"]
        # Diversity check: skip views too close to an already kept plane
        if any(grassmann_dist(basis, t[2]["basis"]) < self.min_div for t in self.heap):
            return
        entry = (score, self._counter, record)
        self._counter += 1
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, entry)
        elif score > self.heap[0][0]:
            heapq.heapreplace(self.heap, entry)

    def results(self):
        return sorted([t[2] for t in self.heap], key=lambda r: -r["score"])
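
# Sketch of the intended behaviour: with k=2, feeding views scoring 3, 7, 5, 9
# leaves the 9- and 7-rated views on the heap. The diversity check runs first,
# so a high-scoring view within min_div of an already kept plane is skipped
# rather than swapped in for its neighbour.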

# ---------------------------------------------------------------------------
# ------------------------------  MAIN LOOP  ---------------------------------
# ---------------------------------------------------------------------------

def run_tour(
    X: np.ndarray,
    steps: int,
    model: str,
    prompt: str,
    k: int,
    min_div: float = 0.15,
    seed: int | None = None,
) -> list[dict]:
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    n, p = X.shape
    # Initial plane seeded by ICA
    basis0 = initial_plane_ica(X)

    topk = TopKViews(k, min_div)

    for t in trange(steps, desc="Tour", unit="view"):
        if t == 0:
            B = basis0
        else:
            B = random_plane(p)
        proj = X @ B  # shape (n, 2)
        png = render_scatter(proj)
        score, summary = rate_with_llm(png, model, prompt)
        rec = {
            "score": score,
            "summary": summary,
            "basis": B.tolist(),  # easier to json-dump later
            "image_bytes": png,
        }
        topk.maybe_add(rec)
    return topk.results()

# ---------------------------------------------------------------------------
# -------------------------------  I/O  --------------------------------------
# ---------------------------------------------------------------------------

def save_results(results: list[dict], out_dir: Path):
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = []
    for rank, rec in enumerate(results, 1):
        fname = out_dir / f"view_{rank:02d}.png"
        with open(fname, "wb") as f:
            f.write(rec["image_bytes"])
        meta.append(
            {
                "rank": rank,
                "score": rec["score"],
                "summary": rec["summary"],
                "png": fname.name,
                "basis_flat": np.array(rec["basis"]).flatten().tolist(),
            }
        )
    with open(out_dir / "views.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
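
# Each views.json entry then looks roughly like (values illustrative):
#   {"rank": 1, "score": 9, "summary": "...", "png": "view_01.png",
#    "basis_flat": [0.12, -0.43, ...]}   # 2*d numbers: the d×2 basis flattened row-major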

# ---------------------------------------------------------------------------
# ------------------------------  SCRIPT  ------------------------------------
# ---------------------------------------------------------------------------

def parse_args():
    ap = argparse.ArgumentParser(description="LLM-guided grand tour")
    ap.add_argument("csv_path", type=str, help="Numeric CSV file to explore")
    ap.add_argument("-o", "--out", default="./llm_tour_results", type=str)
    ap.add_argument("-s", "--steps", default=100, type=int, help="Frames to sample")
    ap.add_argument("-m", "--model", default="gpt-4o-mini", type=str)
    ap.add_argument("-k", "--keep", type=int, help="Top-k views to keep (default round(3*sqrt(d)))")
    ap.add_argument("-p", "--prompt", type=str, default=DEFAULT_PROMPT)
    ap.add_argument("--min-div", type=float, default=0.15)
    ap.add_argument("--seed", type=int)
    return ap.parse_args()


def main():
    args = parse_args()

    # OpenAI key: the client created in rate_with_llm reads OPENAI_API_KEY from
    # the environment, so fail fast here if it is missing.
    if not os.getenv("OPENAI_API_KEY"):
        raise RuntimeError("Set the OPENAI_API_KEY environment variable")

    df = pd.read_csv(args.csv_path)
    X = df.select_dtypes(include=["number"]).values
    if X.size == 0:
        raise ValueError("CSV does not contain numeric columns.")

    X = robust_scale(X)
    d = X.shape[1]
    k = args.keep if args.keep is not None else max(1, round(3 * math.sqrt(d)))
    print(f"Data shape: n={X.shape[0]}, d={d}; keeping top-k={k} views")

    results = run_tour(
        X,
        steps=args.steps,
        model=args.model,
        prompt=args.prompt,
        k=k,
        min_div=args.min_div,
        seed=args.seed,
    )

    out_dir = Path(args.out)
    save_results(results, out_dir)
    print(f"Saved {len(results)} views → {out_dir.resolve()}")


if __name__ == "__main__":
    main()