SigilDERG Student Model V6 Proposed
3th1ca14aX0r · Jun 19th, 2025 · Python
#!/usr/bin/env python
"""
unified_sigil_trainer_v6.py

Trains and exports an ONNX student model with enhanced relational reasoning
and teacher–student distillation support. This version is built to handle
rich, multi-dimensional data, such as that gathered by integrating
rust-crate-pipeline with Crawl4AI.

Key features:
 - Dynamic feature selection via a --feature_cols argument.
 - Two architecture variants: a simple "flat" MLP and a relational variant that
   processes feature-by-feature tokens through a lightweight Transformer encoder.
 - Optional teacher–student distillation: if a teacher_logits file is provided,
   a KD loss is applied.
 - Manifest generation and optional GPG signing for deployment.

Usage examples:
 python unified_sigil_trainer_v6.py --mode trust --csv enriched_data.csv --onnx trust_model.onnx \
      --feature_cols x0,x1,x2,x3,x4,x5,x6,x7
 python unified_sigil_trainer_v6.py --mode classify --csv enriched_classify.csv --onnx classify_model.onnx \
      --feature_cols f1,f2,f3,f4,f5,f6,f7,f8,f9 --relational --teacher_logits teacher_logits.npy --kd_weight 0.5 --sign
"""

import argparse, csv, hashlib, json, pathlib, subprocess
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
from tqdm import tqdm

# ----------------- Global Hyperparameters -----------------
HIDDEN_DIM = 32         # Increase hidden dimensions if richer data permits
NUM_CLASSES = 3         # For classification mode (otherwise regression/trust)
EPOCHS = 25
BATCH = 128
LEARNING_RT = 1e-3
VERSION = "v6-unified"

# ----------------- Model Definitions -----------------

# Flat MLP architecture that uses dynamic input dimensions.
class UnifiedSigilNet(nn.Module):
    def __init__(self, input_dim, mode="trust"):
        super().__init__()
        self.mode = mode
        self.input_dim = input_dim
        self.norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM if mode == "classify" else 1)
        self.out = nn.Linear(HIDDEN_DIM, NUM_CLASSES) if mode == "classify" else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        x = F.gelu(self.fc1(x))
        x = F.dropout(x, p=0.10, training=self.training)
        x = self.fc2(x)
        if self.mode == "trust":
            # No activation on the raw score output: a GELU here would bound the
            # pre-sigmoid value below at about -0.17 and clip scores to roughly
            # [0.46, 1), breaking BCE training on low-trust labels.
            score = torch.sigmoid(x)
            # A simple proxy for confidence: scaled distance from the 0.5 decision point.
            confidence = (score - 0.5).abs() * 2
            return score, confidence
        else:
            x = F.gelu(x)
            return self.out(x)

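# Quick shape check for the flat net (illustrative; not executed by this script):
#
#   net = UnifiedSigilNet(input_dim=8, mode="trust")
#   score, confidence = net(torch.randn(4, 8))  # both shaped (4, 1)
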
# Relational variant that treats each input feature as a token.
class RelationalSigilNet(nn.Module):
    def __init__(self, input_dim, mode="trust"):
        super().__init__()
        self.mode = mode
        self.input_dim = input_dim
        # Each feature is projected into a token embedding of size HIDDEN_DIM.
        self.token_embed_dim = HIDDEN_DIM
        self.feature_proj = nn.Linear(1, self.token_embed_dim)
        # Learnable positional embeddings (one per input feature).
        self.positional = nn.Parameter(torch.zeros(input_dim, self.token_embed_dim))
        # Single-layer Transformer encoder with one attention head.
        self.transformer = nn.TransformerEncoderLayer(
            d_model=self.token_embed_dim, nhead=1, batch_first=True
        )
        # Pool across tokens and process the aggregated representation.
        self.fc = nn.Sequential(
            nn.LayerNorm(self.token_embed_dim),
            nn.Linear(self.token_embed_dim, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(0.10),
            nn.Linear(HIDDEN_DIM, HIDDEN_DIM if mode == "classify" else 1)
        )
        self.out = nn.Linear(HIDDEN_DIM, NUM_CLASSES) if mode == "classify" else nn.Identity()

    def forward(self, x):
        # x shape: (batch, input_dim)
        x = x.unsqueeze(-1)       # (batch, input_dim, 1)
        x = self.feature_proj(x)  # (batch, input_dim, token_embed_dim)
        x = x + self.positional   # Apply positional bias
        x = self.transformer(x)   # Attend relationally across features
        x = x.mean(dim=1)         # Mean-pool the feature tokens
        x = self.fc(x)
        if self.mode == "trust":
            score = torch.sigmoid(x)
            confidence = (score - 0.5).abs() * 2
            return score, confidence
        else:
            # Nonlinearity between fc's final layer and the classifier head;
            # without it the two stacked linear layers collapse into one.
            return self.out(F.gelu(x))

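# And the same check for the relational variant (illustrative):
#
#   net = RelationalSigilNet(input_dim=8, mode="classify")
#   logits = net(torch.randn(4, 8))  # shaped (4, NUM_CLASSES)
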
# ----------------- Data Loading -----------------

def load_csv(path, feature_cols, mode, teacher_logits_file=None):
    """
    Load a CSV file and use only the columns listed in feature_cols.
    The CSV must also have a target column "y" for training.
    If teacher_logits_file is provided, it is assumed to be an .npy file whose
    i-th row corresponds to the teacher's output for the i-th CSV row.
    """
    xs, ys, teacher_targets = [], [], []
    teacher_all = None
    if teacher_logits_file:
        teacher_all = np.load(teacher_logits_file)
    with open(path, newline='') as fh:
        rdr = csv.DictReader(fh)
        for i, row in enumerate(rdr):
            try:
                # Require every feature column; a missing or malformed value
                # raises and the row is skipped. (Silently filtering out
                # missing columns would produce ragged feature vectors.)
                vec = [float(row[col]) for col in feature_cols]
            except (KeyError, TypeError, ValueError):
                continue
            # Skip rows with missing data.
            if any(np.isnan(vec)):
                continue
            xs.append(vec)
            # The target must be in the "y" column.
            ys.append(float(row["y"]) if mode == "trust" else int(row["y"]))
            if teacher_all is not None:
                teacher_targets.append(teacher_all[i])
    if not xs:
        raise ValueError(f"No usable rows found in {path}")
    xs = np.array(xs, dtype=np.float32)
    # Min-max normalize each feature column to [0, 1].
    xs_min = xs.min(axis=0)
    xs_max = xs.max(axis=0)
    xs_norm = (xs - xs_min) / (xs_max - xs_min + 1e-8)
    x = torch.tensor(xs_norm, dtype=torch.float32)
    if mode == "trust":
        y = torch.tensor(ys, dtype=torch.float32)
    else:
        y = torch.tensor(ys, dtype=torch.long)
    if teacher_all is not None:
        teacher_tensor = torch.tensor(np.array(teacher_targets), dtype=torch.float32)
        return td.TensorDataset(x, y, teacher_tensor)
    return td.TensorDataset(x, y)

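# Illustrative CSV layout (hypothetical values) for --feature_cols x0,...,x7:
#
#   x0,x1,x2,x3,x4,x5,x6,x7,y
#   0.12,0.80,0.33,0.91,0.05,0.47,0.66,0.21,1.0
#   0.73,0.15,0.58,0.02,0.99,0.31,0.44,0.87,0.0
#
# In trust mode "y" is a float score; in classify mode it is an integer class id.
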
# ----------------- Training Loop -----------------

def train(model, loader, mode, kd_weight=0.0):
    """
    Training loop with support for teacher–student distillation.
    If kd_weight > 0, the loader is expected to yield (x, y, teacher_target).
    """
    base_loss_fn = nn.BCELoss() if mode == "trust" else nn.CrossEntropyLoss()
    # Trust mode matches the teacher target directly with MSE (the target is
    # assumed to be a probability-like score); classify mode uses KL divergence
    # between the student's log-softmax and the teacher's softmax.
    kd_loss_fn = nn.MSELoss() if mode == "trust" else nn.KLDivLoss(reduction="batchmean")

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RT)
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for batch in pbar:
            optimizer.zero_grad()
            if len(batch) == 3:
                xb, yb, teacher_target = batch
            else:
                xb, yb = batch
                teacher_target = None
            out = model(xb)
            if mode == "trust":
                # Trust mode returns (score, confidence); train on the score.
                pred = out[0].squeeze(1)
            else:
                pred = out
            loss = base_loss_fn(pred, yb)
            if teacher_target is not None and kd_weight > 0:
                if mode == "trust":
                    kd_loss = kd_loss_fn(pred, teacher_target)
                else:
                    kd_loss = kd_loss_fn(F.log_softmax(pred, dim=-1), F.softmax(teacher_target, dim=-1))
                # Convex mix of the hard-label loss and the distillation loss.
                loss = (1 - kd_weight) * loss + kd_weight * kd_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            pbar.set_postfix({"loss": running_loss / (pbar.n + 1)})
    return model

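# Illustrative sketch of producing a teacher-logits file (teacher_predict is a
# hypothetical stand-in for whatever larger model acts as the teacher). Row i of
# the array must line up with row i of the CSV, including rows the loader skips:
#
#   logits = np.stack([teacher_predict(row) for row in all_csv_rows])
#   np.save("teacher_logits.npy", logits)
#
# Expected shape: (num_rows, NUM_CLASSES) for classify, (num_rows,) for trust.
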
# ----------------- ONNX Export -----------------

def export_onnx(model, path, mode):
    model.eval()
    dummy = torch.randn(1, model.input_dim)
    if mode == "trust":
        output_names = ["score", "confidence"]
    else:
        output_names = ["logits"]
    torch.onnx.export(
        model,
        dummy,
        path,
        input_names=["x"],
        output_names=output_names,
        dynamic_axes={"x": {0: "batch_size"}},
        opset_version=17,
    )

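# Illustrative sketch (not called by this script) of consuming the exported
# model with onnxruntime. Assumes the onnxruntime package is installed; the
# input name "x" matches export_onnx above. In trust mode run() returns
# [score, confidence]; in classify mode it returns [logits].
def example_onnx_inference(onnx_path, features):
    import onnxruntime as ort  # local import: the trainer itself does not need it
    sess = ort.InferenceSession(onnx_path)
    arr = np.asarray(features, dtype=np.float32).reshape(1, -1)
    return sess.run(None, {"x": arr})
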
# ----------------- Manifest + Signing -----------------

def hash_file(path):
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()

def write_manifest(path, sha, mode, input_dim, hidden_dim, classes, version):
    manifest = {
        # The model's file name maps directly to its SHA-256 digest.
        pathlib.Path(path).name: sha,
        "version": version,
        "mode": mode,
        "input_dim": input_dim,
        "hidden_dim": hidden_dim,
        "classes": classes,
        "timestamp": datetime.now().isoformat(),
    }
    manifest_path = pathlib.Path(path).with_name("model_manifest.json")
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)

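# Illustrative manifest (hypothetical digest) for a trust model with 8 features:
#
#   {
#     "trust_model.onnx": "3f1a...9c",
#     "version": "v6-unified",
#     "mode": "trust",
#     "input_dim": 8,
#     "hidden_dim": 32,
#     "classes": 1,
#     "timestamp": "2025-06-19T12:00:00"
#   }
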
def sign_model(path):
    try:
        subprocess.run(["gpg", "--detach-sign", "--armor", path], check=True)
        print("GPG signature created.")
    except Exception as e:
        print("GPG signing failed:", e)

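# gpg --detach-sign --armor writes an ASCII-armored signature next to the model
# as <model>.onnx.asc; consumers can verify the download with, e.g.:
#
#   gpg --verify trust_model.onnx.asc trust_model.onnx
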
# ----------------- Main -----------------

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["trust", "classify"], required=True, help="Operation mode")
    ap.add_argument("--csv", required=True, help="CSV file with enriched training data")
    ap.add_argument("--onnx", required=True, help="Output ONNX model file")
    ap.add_argument("--feature_cols", required=False, default="", help="Comma-separated feature column names (e.g., x0,x1,x2)")
    ap.add_argument("--sign", action="store_true", help="GPG-sign the exported ONNX model")
    ap.add_argument("--relational", action="store_true", help="Use the relational transformer variant")
    ap.add_argument("--teacher_logits", default=None, help="Optional teacher logits file (npy)")
    ap.add_argument("--kd_weight", type=float, default=0.0, help="Weight for KD loss (0 disables distillation)")
    args = ap.parse_args()

    # Determine features: use the provided list or fall back to a small default set.
    if args.feature_cols:
        feature_cols = [col.strip() for col in args.feature_cols.split(",")]
    else:
        # Defaults; adjust as needed if enrichment adds more columns.
        feature_cols = ["x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"]

    input_dim = len(feature_cols)

    # Choose the student model architecture.
    if args.relational:
        model = RelationalSigilNet(input_dim=input_dim, mode=args.mode)
    else:
        model = UnifiedSigilNet(input_dim=input_dim, mode=args.mode)

    # Load the dataset (with teacher logits if provided).
    dataset = load_csv(args.csv, feature_cols, args.mode, teacher_logits_file=args.teacher_logits)
    loader = td.DataLoader(dataset, batch_size=BATCH, shuffle=True)

    # Train the model.
    model = train(model, loader, args.mode, kd_weight=args.kd_weight)

    # Export the trained model to ONNX.
    export_onnx(model, args.onnx, args.mode)
    sha = hash_file(args.onnx)
    classes = NUM_CLASSES if args.mode == "classify" else 1
    write_manifest(args.onnx, sha, args.mode, input_dim, HIDDEN_DIM, classes, VERSION)
    if args.sign:
        sign_model(args.onnx)
    print(f"Model and manifest ready for mode: {args.mode}, using features: {feature_cols}")