
DeepSynapse.py

Feb 7th, 2025
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
=====================================================================
DeepSynapse v4.1: Comprehensive Phase-Controlled Training System
=====================================================================
DeepSynapse (Dynamic Error-enhanced, Emergent Phase-controlled Self-Optimizing
Neural Adaptive Processing System Engine) is an advanced reinforcement learning
framework designed for training language models that can autonomously evolve
their capabilities. By combining adaptive parameter expansion, multi-objective reward
optimization, and a phase-controlled curriculum, DeepSynapse achieves robust,
interpretable, and error-resilient structured reasoning.

Core Innovations:
------------------
1. Dynamic LoRA-Head Scaling with Meta-Contextual Adaptation:
   - Implements phase-progressive adapter rank expansion (64 → 128 → 192, etc.)
     with smoothing to ensure stable transitions.
   - Adapts capacity based on context embeddings derived from current batches.

2. Triple Distractor Anchoring:
   - Generates multi-modal distractors:
     * Numeric: ±20% variance, sign flips, rounding modifications.
     * Semantic: Context-aware synonym rotations using WordNet.
     * Unit: Multi-dimensional conversion using predefined mappings.

3. KL-Temperature Co-Regulation:
   - Uses a cosine-decaying temperature (0.9 → 0.3) along with phase-aligned KL divergence penalties.
   - Helps prevent reward hacking and maintains stable output distributions.

4. Reinforced Critique Validation:
   - Extracts dedicated <critique> blocks and evaluates them via a RoBERTa-based classifier.
   - Applies phase-dependent penalties if the self-critique does not match solution correctness.

5. Phase-Controlled Curriculum & Component Locking:
   - Three-stage training curriculum:
     * Phase 0: Structural Compliance.
     * Phase 1: Reasoning Validation.
     * Phase 2: Precision Refinement.
   - Dynamically activates/deactivates reward components based on phase.

6. Omnidirectional Reward Fusion & Calibration:
   - Computes a 5-dimensional reward vector (Structure, Contrastive, Critique, Correctness, KL).
   - Calibrates raw rewards with a neural weight allocator and evolves reward functions based on training history.

7. XML Structural Guardian:
   - Enforces strict XML formatting (<reasoning>, <answer>, and <critique> tags).
   - Applies dynamic length penalties to discourage verbosity.

8. Integrated Performance Monitoring:
   - Fully integrated with Weights & Biases (W&B) for real-time telemetry and debugging.

Advanced Roadmap:
-----------------
Future evolutions include memory-augmented networks for long-term retention, dynamic
gradient accumulation based on advanced metrics, and a fully self-evolving reward system.

Expected Emergent Capabilities:
--------------------------------
- Enhanced counterfactual reasoning, robust self-debugging, and superior zero-shot problem-solving.
"""

import re
import torch
import wandb
import random
import numpy as np
import unittest
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from unsloth import FastLanguageModel, is_bfloat16_supported
from vllm import SamplingParams

# Additional imports for distractor generation
import nltk
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')
from nltk.corpus import wordnet
from itertools import chain

# -----------------------------------------------------------------------------
# Configuration Constants
# -----------------------------------------------------------------------------
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
REF_MODEL_NAME = "Qwen/Qwen2.5-3B"  # Reference model for KL divergence
MAX_SEQ_LENGTH = 2048
INITIAL_LORA_RANK = 64
LORA_RANK_INCREMENT = 64  # Progression: 64 -> 128 -> 192, etc.
PHASE_TRANSITION_STEPS = 300

# Phase-specific reward weights for 3 phases (0, 1, 2)
PHASE_WEIGHTS = {
    'structure': [0.6, 0.3, 0.1],
    'contrastive': [0.0, 0.4, 0.2],
    'critique': [0.1, 0.2, 0.3],
    'correctness': [0.1, 0.3, 0.6],
    'kl': [0.0, 0.1, 0.2],
}
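
# Worked example (illustrative numbers, not produced by the trainer): in phase 0
# a reward vector of {structure: 1.0, contrastive: 0.5, critique: 1.0,
# correctness: 2.0, kl: -0.1} aggregates to
#   1.0*0.6 + 0.5*0.0 + 1.0*0.1 + 2.0*0.1 + (-0.1)*0.0 = 0.9,
# so early training is dominated by structural compliance, while the same vector
# under phase 2 weights yields 1.0*0.1 + 0.5*0.2 + 1.0*0.3 + 2.0*0.6 + (-0.1)*0.2 = 1.68,
# dominated by correctness.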

SYSTEM_PROMPT = """Respond using structured reasoning followed by a concise answer:
<reasoning>
Step-by-step logical explanation...
</reasoning>
<answer>
Final numerical answer only
</answer>
<critique>
Your self-critique here.
</critique>"""

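# ---------------------------------------------------------------------------
# Hedged sketch: the header docstring describes a cosine-decaying temperature
# (0.9 → 0.3), while the GRPOConfig instances below approximate it with a
# linear lambda. The helper here is one possible cosine variant, not part of
# the original training code; total_steps=900 mirrors max_steps used by the
# trainers.
# ---------------------------------------------------------------------------
import math

def cosine_temperature(step: int, total_steps: int = 900,
                       t_max: float = 0.9, t_min: float = 0.3) -> float:
    """Cosine decay from t_max at step 0 to t_min at total_steps."""
    progress = min(1.0, step / total_steps)
    return t_min + 0.5 * (t_max - t_min) * (1 + math.cos(math.pi * progress))
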
# -----------------------------------------------------------------------------
# Advanced Components
# -----------------------------------------------------------------------------

# 1. Hybrid Modular Memory: memory-augmented neural network (MANN)
class NeuralMemoryBank:
    def __init__(self, model_dim=1024):
        self.memory = []  # Stores (key, value) pairs as tensors
        self.attention = torch.nn.MultiheadAttention(embed_dim=model_dim, num_heads=4)

    def retrieve(self, query, k=3):
        # query: tensor of shape (embed_dim,)
        query = query.unsqueeze(0)  # (1, embed_dim)
        if not self.memory:
            return query.squeeze(0)
        keys = torch.stack([m[0] for m in self.memory])    # (N, embed_dim)
        values = torch.stack([m[1] for m in self.memory])  # (N, embed_dim)
        # Reshape keys and values to (N, 1, embed_dim)
        keys = keys.unsqueeze(1)
        values = values.unsqueeze(1)
        query = query.unsqueeze(0)  # (1, 1, embed_dim)
        attn_output, _ = self.attention(query, keys, values)
        return attn_output.squeeze(0)[:k]

    def store(self, key, value):
        self.memory.append((key.detach(), value.detach()))
        if len(self.memory) > 1000:  # Use FIFO eviction when memory is full
            self.memory.pop(0)

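# Usage sketch (illustrative only, not called by the trainer): store a few
# random (key, value) embeddings and retrieve against a query. Dimensions match
# the default model_dim=1024 above.
def _demo_neural_memory_bank():
    bank = NeuralMemoryBank(model_dim=1024)
    for _ in range(5):
        emb = torch.randn(1024)
        bank.store(emb, emb)              # here keys double as values
    fused = bank.retrieve(torch.randn(1024))
    print("retrieved shape:", fused.shape)
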
# 2. Meta-Contextual Adaptation: lightweight hypernetwork for LoRA rank scaling
class HyperNetwork(torch.nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256):
        super(HyperNetwork, self).__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)  # Predicts a scaling factor
        )

    def forward(self, context_embedding):
        # context_embedding: tensor of shape (batch, input_dim)
        mean_embedding = context_embedding.mean(dim=0, keepdim=True)
        scaling = self.net(mean_embedding)  # shape (1, 1)
        return scaling.squeeze(0)  # returns a scalar

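# Usage sketch (illustrative only): feed a batch of context embeddings and turn
# the predicted scaling factor into a candidate LoRA rank, mirroring the logic
# in DynamicLoRAWithContext.contextual_rank_adjustment further below.
def _demo_hypernetwork_rank():
    hypernet = HyperNetwork()
    context = torch.randn(8, 512)                # 8 context embeddings
    factor = 1 + 0.1 * hypernet(context).item()  # small multiplicative bump
    print("candidate rank:", int(INITIAL_LORA_RANK * factor))
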
# 3. Dynamic Weight Adjustment: neural network-based weight allocator for rewards
class NeuralWeightAllocator(torch.nn.Module):
    def __init__(self, num_rewards):
        super(NeuralWeightAllocator, self).__init__()
        self.net = torch.nn.Linear(num_rewards * 3, num_rewards)

    def forward(self, reward_history):
        # reward_history: tensor of shape (3, num_rewards)
        hist_flat = reward_history.flatten().unsqueeze(0)  # shape (1, num_rewards*3)
        weights = self.net(hist_flat)
        return torch.softmax(weights, dim=1).squeeze(0)  # shape (num_rewards,)

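# Usage sketch (illustrative only): the allocator expects a (3, num_rewards)
# history, i.e. the last three reward vectors, and returns a softmax-normalized
# weight per reward component.
def _demo_weight_allocator():
    allocator = NeuralWeightAllocator(num_rewards=5)
    history = torch.randn(3, 5)       # three past 5-dimensional reward vectors
    weights = allocator(history)
    print("component weights:", weights.tolist())  # sums to 1.0
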
# 4. Auto-Discovered Reward Components: uses LLM-generated reward templates to evolve reward functions
class RewardEvolution:
    def __init__(self, generator_model):
        self.generator = generator_model  # Text-generation pipeline

    def generate_new_reward(self, training_history):
        history_str = ", ".join(str(r) for r in training_history)
        prompt = (f"Analyze these training reward values: {history_str}. "
                  "Propose a multiplicative factor to improve reward calibration.")
        output = self.generator(prompt, max_length=50, truncation=True)[0]['generated_text']
        factor = self._parse_factor(output)
        print(f"[RewardEvolution] New calibration factor: {factor}")
        return lambda rewards: [r * factor for r in rewards]

    def _parse_factor(self, text):
        matches = re.findall(r"[\d\.]+", text)
        if matches:
            try:
                return float(matches[0])
            except ValueError:
                return 1.0
        return 1.0

# 5. Dynamic Gradient Accumulation: adaptive accumulator using an EWMA of gradient variance
class AdaptiveAccumulator:
    def __init__(self, init_steps=4, alpha=0.3):
        self.accum_steps = init_steps
        self.alpha = alpha
        self.ewma = None

    def update(self, gradients):
        current_var = gradients.var().item() if gradients.numel() > 0 else 0.0
        if self.ewma is None:
            self.ewma = current_var
        else:
            self.ewma = self.alpha * current_var + (1 - self.alpha) * self.ewma
        # Adjust accumulation steps: high variance -> fewer steps (react faster);
        # low variance -> more steps for smoother updates.
        if self.ewma > 0.1:
            self.accum_steps = max(2, self.accum_steps - 1)
        else:
            self.accum_steps = min(8, self.accum_steps + 1)
        print(f"[AdaptiveAccumulator] EWMA: {self.ewma:.4f}, Accumulation Steps: {self.accum_steps}")
        return self.accum_steps

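# Usage sketch (illustrative only): feed per-step gradient (or metric) tensors
# and watch the accumulation step count adapt to their variance.
def _demo_adaptive_accumulator():
    accum = AdaptiveAccumulator(init_steps=4, alpha=0.3)
    for _ in range(3):
        accum.update(torch.randn(100) * 0.05)  # low variance -> steps drift up
    accum.update(torch.randn(100) * 2.0)       # high variance -> steps drop
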
# 6. Selective Activation Recompilation: activation caching for efficiency.
class EfficientTrainer(GRPOTrainer):
    def __init__(self, *args, **kwargs):
        super(EfficientTrainer, self).__init__(*args, **kwargs)
        self.activation_cache = {}

    def training_step(self, batch):
        with torch.no_grad():
            base_out = self.model(**batch, output_hidden_states=True)
        if hasattr(base_out, "hidden_states"):
            self.activation_cache['hidden'] = base_out.hidden_states
        return super(EfficientTrainer, self).training_step(batch)

# 7. Curriculum-Driven Multi-Objective Learning: phase-adaptive curriculum sampler.
class CurriculumSampler:
    def __init__(self, dataset):
        self.dataset = dataset
        self.difficulty_scores = self._calculate_difficulty()

    def _calculate_difficulty(self):
        scores = []
        for ex in self.dataset:
            score = len(ex["prompt"]) / 100.0
            scores.append(score)
        return scores

    def sample_batch(self, phase):
        dataset_size = len(self.dataset)
        sorted_indices = np.argsort(self.difficulty_scores)
        if phase == 0:
            idxs = sorted_indices[: dataset_size // 3]
        elif phase == 1:
            idxs = sorted_indices[dataset_size // 3: 2 * dataset_size // 3]
        else:
            idxs = sorted_indices[2 * dataset_size // 3:]
        return self.dataset.select(list(idxs))

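# Usage sketch (illustrative only): slice the processed dataset by difficulty
# for the current phase. Assumes a dataset with a "prompt" column, e.g. the
# output of GSM8KProcessor().process_dataset() defined further below.
def _demo_curriculum_sampler(dataset):
    sampler = CurriculumSampler(dataset)
    easy_third = sampler.sample_batch(phase=0)  # shortest prompts
    hard_third = sampler.sample_batch(phase=2)  # longest prompts
    print(len(easy_third), len(hard_third))
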
# 8. Emergent Skill Probes: automated capability tests during validation.
class EmergentSkillValidator:
    TEST_PROMPTS = {
        "counterfactual": "If a problem stated A instead of B, how would your solution change?",
        "generalization": "Solve this unseen problem: What is the square root of 256?",
        "self_critique": "Identify potential flaws in the following solution: <reasoning>...<answer>...</answer></reasoning>"
    }

    def __init__(self, model):
        self.model = model

    def run_tests(self):
        results = {}
        for skill, template in self.TEST_PROMPTS.items():
            response = self.model.generate(template, SamplingParams(temperature=0.7, max_tokens=100))
            results[skill] = self._evaluate_response(skill, response[0].outputs[0].text)
        return results

    def _evaluate_response(self, skill, response):
        return len(response) > 20

# 9. Reward Orchestration: base reward system; EnhancedRewardOrchestrator below extends it.
class RewardOrchestrator:
    def __init__(self, tokenizer, main_model):
        self.tokenizer = tokenizer
        self.main_model = main_model
        self.ref_tokenizer = AutoTokenizer.from_pretrained(REF_MODEL_NAME)
        self.ref_model = AutoModelForCausalLM.from_pretrained(REF_MODEL_NAME)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ref_model.to(device)
        self.validator = pipeline(
            "text-classification",
            model="roberta-base-openai-detector",
            device=0 if torch.cuda.is_available() else -1
        )

    def calculate_rewards(self, phase, prompts, completions, answers, distractors):
        rewards_dict = {
            'structure': self._structural_reward(completions),
            'contrastive': self._contrastive_reward(completions, answers, distractors),
            'critique': self._critique_reward(completions, answers, phase),
            'correctness': self._correctness_reward(completions, answers),
            'kl': self._kl_reward(prompts)
        }
        return rewards_dict

    def _structural_reward(self, completions):
        rewards = []
        for comp in completions:
            has_reasoning = "<reasoning>" in comp and "</reasoning>" in comp and (comp.find("<reasoning>") < comp.find("</reasoning>"))
            has_answer = "<answer>" in comp and "</answer>" in comp and (comp.find("<answer>") < comp.find("</answer>"))
            has_critique = "<critique>" in comp and "</critique>" in comp and (comp.find("<critique>") < comp.find("</critique>"))
            valid = has_reasoning and has_answer
            score = 1.0 if valid else -1.0
            if has_critique:
                score += 0.2 if comp.find("<answer>") < comp.find("<critique>") else -0.1
            length_penalty = max(0, (len(comp) - 200) // 50 * 0.1)
            rewards.append(score - length_penalty)
        return rewards

    def _contrastive_reward(self, completions, answers, distractors):
        rewards = []
        for comp, ans, dists in zip(completions, answers, distractors):
            comp_val = self._parse_numeric(comp)
            ans_val = self._parse_numeric(ans)
            if np.isnan(comp_val) or np.isnan(ans_val):
                rewards.append(-1.0)
                continue
            dist_diffs = [abs(comp_val - self._parse_numeric(d)) for d in dists if self._is_number(d)]
            min_dist = min(dist_diffs) if dist_diffs else 0.0
            diff = abs(comp_val - ans_val)
            reward = 2.0 if diff < 0.01 else 1.0 / (1 + diff) - 0.3 * min_dist
            rewards.append(max(min(reward, 2.0), -1.0))
        return rewards

    def _critique_reward(self, completions, answers, phase):
        rewards = []
        for comp, ans in zip(completions, answers):
            critique = self._extract_critique(comp)
            if not critique:
                rewards.append(-1.5 * [0.8, 1.0, 1.2][phase])
                continue
            valid = self.validator(critique[:512])[0]["label"] == "REAL"
            try:
                comp_val = self._parse_numeric(comp)
                ans_val = self._parse_numeric(ans)
                correct = abs(comp_val - ans_val) < 0.01
            except (ValueError, TypeError):
                correct = False
            base = 1.0 if valid else -1.5
            phase_weight = [0.8, 1.0, 1.2][phase]
            rewards.append(base * phase_weight * (1.2 if correct else 0.8))
        return rewards

    def _correctness_reward(self, completions, answers):
        rewards = []
        for c, a in zip(completions, answers):
            try:
                if abs(self._parse_numeric(c) - self._parse_numeric(a)) < 0.01:
                    rewards.append(2.0)
                else:
                    rewards.append(-1.0)
            except (ValueError, TypeError):
                rewards.append(-1.0)
        return rewards

    def _kl_reward(self, prompts):
        inputs = self.ref_tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LENGTH)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            base_logits = self.ref_model(**inputs).logits
            current_logits = self.main_model(**inputs).logits
        kl_div = torch.nn.functional.kl_div(
            torch.log_softmax(current_logits, dim=-1),
            torch.softmax(base_logits, dim=-1),
            reduction='batchmean'
        )
        return [-kl_div.item()] * len(prompts)

    def _extract_critique(self, text):
        match = re.search(r"<critique>(.*?)</critique>", text, re.DOTALL)
        return match.group(1).strip() if match else ""

    def _parse_numeric(self, text):
        try:
            m = re.search(r"[-+]?\d*\.?\d+", text)
            return float(m.group()) if m else float('nan')
        except (ValueError, TypeError):
            return float('nan')

    def _is_number(self, s):
        try:
            float(s)
            return True
        except (ValueError, TypeError):
            return False

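# Worked example (illustrative): for a 230-character completion containing
# well-ordered <reasoning> and <answer> blocks plus a <critique> block after the
# answer, _structural_reward gives 1.0 (valid structure) + 0.2 (critique after
# answer) - 0.0 length penalty ((230 - 200) // 50 * 0.1 = 0) = 1.2; the same
# completion padded to 460 characters would lose (460 - 200) // 50 * 0.1 = 0.5.
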
# EnhancedRewardOrchestrator: adds memory retrieval and weight allocation.
class EnhancedRewardOrchestrator(RewardOrchestrator):
    def __init__(self, tokenizer, main_model):
        super().__init__(tokenizer, main_model)
        self.memory = NeuralMemoryBank()
        self.weight_allocator = NeuralWeightAllocator(num_rewards=5)

    def calculate_rewards(self, phase, prompts, completions, answers, distractors):
        base_rewards = super().calculate_rewards(phase, prompts, completions, answers, distractors)
        # Optionally integrate a memory bonus (for demonstration, we use a small constant bonus)
        memory_bonus = 0.1
        rewards_list = []
        reward_keys = ['structure', 'contrastive', 'critique', 'correctness', 'kl']
        for i in range(len(prompts)):
            rewards_dict = {k: base_rewards[k][i] + memory_bonus for k in reward_keys}
            rewards_list.append(rewards_dict)
        return rewards_list

# 10. Dynamic LoRA Adapter (base version)
class DynamicLoRA:
    def __init__(self, base_model):
        self.model = base_model
        self.current_rank = INITIAL_LORA_RANK
        self._initialize_lora()

    def _initialize_lora(self):
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=INITIAL_LORA_RANK,
            lora_alpha=INITIAL_LORA_RANK * 2,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            use_gradient_checkpointing=True,
        )

    def expand_rank(self, new_rank):
        if new_rank <= self.current_rank:
            return
        try:
            adapter_state = self.model.get_adapter_state()
            new_config = {**self.model.peft_config, "r": new_rank, "lora_alpha": new_rank * 2}
            self.model = FastLanguageModel.inject_adapter(self.model, new_config)
            self.model.load_adapter(adapter_state, strict=False)
            self.current_rank = new_rank
            print(f"[DynamicLoRA] LoRA rank expanded to {new_rank}")
        except Exception as e:
            print(f"[DynamicLoRA] Error during LoRA expansion: {e}")

# 11. DynamicLoRAWithContext: uses a hypernetwork for contextual rank adjustment.
class DynamicLoRAWithContext(DynamicLoRA):
    def __init__(self, base_model):
        super().__init__(base_model)
        self.hypernet = HyperNetwork()

    def contextual_rank_adjustment(self, context_embeddings=None):
        if context_embeddings is None:
            context_embeddings = torch.randn(1, 512).to(next(self.model.parameters()).device)
        scaling_factor = self.hypernet(context_embeddings)
        factor = 1 + 0.1 * scaling_factor.item()
        new_rank = int(self.current_rank * factor)
        new_rank = min(new_rank, self.current_rank + LORA_RANK_INCREMENT)
        if new_rank > self.current_rank:
            print(f"[DynamicLoRAWithContext] Adjusting rank from {self.current_rank} to {new_rank} based on context")
            self.expand_rank(new_rank)

# 12. GSM8KProcessor: processes the GSM8K dataset.
class GSM8KProcessor:
    def __init__(self):
        self.unit_conversions = {
            'km': (0.621371, 'mi'),
            'hours': (60, 'minutes'),
            '$': (100, 'cents'),
        }

    def process_dataset(self):
        dataset = load_dataset("gsm8k", "main")["train"]
        return dataset.map(self._process_example, remove_columns=dataset.column_names)

    def _process_example(self, example):
        answer = self._extract_answer(example["answer"])
        return {
            "prompt": f"Solve: {example['question']}\nUse XML structure:",
            "answer": answer,
            "distractors": self._generate_distractors(answer),
        }

    def _extract_answer(self, solution):
        # GSM8K solutions end with "#### <final answer>"; fall back to \boxed{}
        # and dollar-amount patterns for other formats.
        match = re.search(r"####\s*([^\n]+)", solution)
        if not match:
            match = re.search(r"\\boxed{([^}]+)}", solution)
        if not match:
            match = re.search(r"\$\s*([+-]?\d+\.?\d*)", solution)
        extracted = match.group(1) if match else "0"
        if extracted == "0":
            print(f"[GSM8KProcessor] Warning: No valid answer found in solution: {solution}")
        return self._normalize_value(extracted)

    def _normalize_value(self, value_str):
        return value_str.replace(",", "").strip()

    def _generate_distractors(self, answer):
        value, unit = self._parse_value_unit(answer)
        return [
            self._numeric_distractor(value, unit),
            self._unit_distractor(value, unit),
            self._semantic_distractor(value, unit)
        ]

    def _parse_value_unit(self, text):
        match = re.match(r"([+-]?\d+\.?\d*)(.*)", text.strip())
        if match:
            return float(match.group(1)), match.group(2).strip()
        return 0.0, ""

    def _numeric_distractor(self, value, unit):
        variation = value * random.choice([1.2, 0.8, -1])
        return f"{variation:.2f}{unit}"

    def _unit_distractor(self, value, unit):
        for pattern, (factor, new_unit) in self.unit_conversions.items():
            if pattern in unit:
                return f"{value * factor:.2f} {new_unit}"
        return f"{value}{random.choice([' m', ' kg', ' s'])}"

    def _semantic_distractor(self, value, unit):
        if unit:
            synsets = wordnet.synsets(unit)
            lemmas = set(chain.from_iterable([syn.lemma_names() for syn in synsets])) if synsets else set()
            if lemmas:
                synonym = random.choice(list(lemmas))
                return f"approximately {value:.1f} {synonym}"
            variations = [
                f"approximately {value:.1f} {unit}",
                f"around {value:.1f} {unit}",
                f"roughly {value:.1f} {unit}",
                f"nearly {value:.1f} {unit}"
            ]
            return random.choice(variations)
        return f"~{value:.0f}"

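# Usage sketch (illustrative only): generate the three distractor types for a
# single extracted answer without loading the full dataset. Outputs vary because
# the distractor helpers draw random variations.
def _demo_distractors():
    proc = GSM8KProcessor()
    answer = proc._extract_answer("She pays 5 * 3 = 15 dollars.\n#### 15")
    print(answer, proc._generate_distractors(answer))  # e.g. ['18.00', '15.0 m', '~15']
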
# 13. DeepCoralTrainer: base trainer for DeepSynapse training.
class DeepCoralTrainer:
    def __init__(self):
        self.base_model, self.tokenizer = FastLanguageModel.from_pretrained(
            MODEL_NAME,
            max_seq_length=MAX_SEQ_LENGTH,
            load_in_4bit=True
        )
        self.lora_manager = DynamicLoRA(self.base_model)
        self.dataset = GSM8KProcessor().process_dataset()
        self.reward_system = RewardOrchestrator(self.tokenizer, self.lora_manager.model)
        self.trainer = None

    def configure_training(self):
        args = GRPOConfig(
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            max_steps=900,
            learning_rate=2e-5,
            temperature_scheduler=lambda s: 0.9 - 0.6 * min(1, s / 900),
            # Clamp the phase index so step 900 does not index past the 3 phase weights.
            kl_weight_scheduler=lambda s: PHASE_WEIGHTS['kl'][min(s // PHASE_TRANSITION_STEPS, 2)],
            report_to="wandb"
        )
        self.trainer = GRPOTrainer(
            model=self.lora_manager.model,
            args=args,
            train_dataset=self.dataset,
            reward_func=self._phase_aware_reward,
            reward_aggregator=self._aggregate_rewards,
        )
        return self.trainer

    def _phase_aware_reward(self, prompts, completions, answers, distractors):
        phase = min(self.trainer.state.global_step // PHASE_TRANSITION_STEPS, 2)
        return self.reward_system.calculate_rewards(phase, prompts, completions, answers, distractors)

    def _aggregate_rewards(self, phase, rewards):
        aggregated = []
        for r in rewards:
            agg = sum(r[comp] * PHASE_WEIGHTS[comp][phase] for comp in PHASE_WEIGHTS.keys())
            aggregated.append(agg)
        return aggregated

    def execute_training(self):
        wandb.init(project="DEEPCORAL-X")
        trainer = self.configure_training()
        try:
            for step, batch in enumerate(trainer.dataloader):
                current_phase = step // PHASE_TRANSITION_STEPS
                new_rank = INITIAL_LORA_RANK + current_phase * LORA_RANK_INCREMENT
                if new_rank > self.lora_manager.current_rank:
                    self.lora_manager.expand_rank(new_rank)
                    # Warm up the expanded adapter with a halved learning rate for one step.
                    original_lr = trainer.args.learning_rate
                    trainer.args.learning_rate = original_lr * 0.5
                    warmup_metrics = trainer.training_step(batch)
                    wandb.log({"warmup": True, "lr": trainer.args.learning_rate}, step=step)
                    trainer.args.learning_rate = original_lr
                    metrics = warmup_metrics
                else:
                    metrics = trainer.training_step(batch)
                wandb.log({
                    "phase": current_phase,
                    "lora_rank": self.lora_manager.current_rank,
                    **metrics
                }, step=step)
                if step % 100 == 0:
                    self._validation_check()
        finally:
            self.lora_manager.model.save_lora("final_adapters")
            wandb.finish()

    def _validation_check(self):
        sample_prompts = [
            "Solve: If a train travels 300 km in 3 hours, what is its speed? Use XML structure:",
            "Solve: A store sells apples for $0.50 each. How much do 12 apples cost? Use XML structure:"
        ]
        sampling_params = SamplingParams(temperature=0.7, max_tokens=200)
        for prompt in sample_prompts:
            completion = self.lora_manager.model.generate(prompt, sampling_params)
            print(f"[Validation] Prompt: {prompt}")
            print(f"[Validation] Completion: {completion}")

# 14. EnhancedDeepCoralTrainer: incorporates the advanced modules.
class EnhancedDeepCoralTrainer(DeepCoralTrainer):
    def __init__(self):
        super().__init__()
        self.lora_manager = DynamicLoRAWithContext(self.base_model)
        self.reward_system = EnhancedRewardOrchestrator(self.tokenizer, self.lora_manager.model)
        self.curriculum = CurriculumSampler(self.dataset)
        self.skill_validator = EmergentSkillValidator(self.lora_manager.model)
        self.grad_accumulator = AdaptiveAccumulator(init_steps=4, alpha=0.3)
        self.reward_evolution = RewardEvolution(
            generator_model=pipeline("text-generation", model=MODEL_NAME, tokenizer=MODEL_NAME)
        )
        self.reward_calibrator = NeuralWeightAllocator(num_rewards=5)

    def configure_training(self):
        args = GRPOConfig(
            per_device_train_batch_size=4,
            gradient_accumulation_steps=self.grad_accumulator.accum_steps,
            max_steps=900,
            learning_rate=2e-5,
            temperature_scheduler=lambda s: 0.9 - 0.6 * min(1, s / 900),
            # Clamp the phase index so step 900 does not index past the 3 phase weights.
            kl_weight_scheduler=lambda s: PHASE_WEIGHTS['kl'][min(s // PHASE_TRANSITION_STEPS, 2)],
            report_to="wandb"
        )
        self.trainer = EfficientTrainer(
            model=self.lora_manager.model,
            args=args,
            train_dataset=self.dataset,
            reward_func=self._phase_aware_reward,
            reward_aggregator=self._aggregate_rewards,
        )
        return self.trainer

    def _phase_aware_reward(self, prompts, completions, answers, distractors):
        phase = min(self.trainer.state.global_step // PHASE_TRANSITION_STEPS, 2)
        try:
            # Simulate context-embedding extraction; replace with a real encoder if available.
            context_embeddings = torch.randn(1, 512).to(next(self.lora_manager.model.parameters()).device)
        except Exception as e:
            print(f"[EnhancedDeepCoralTrainer] Error obtaining context embeddings: {e}")
            context_embeddings = None
        self.lora_manager.contextual_rank_adjustment(context_embeddings)
        return self.reward_system.calculate_rewards(phase, prompts, completions, answers, distractors)

    def _aggregate_rewards(self, phase, rewards):
        aggregated = []
        for r in rewards:
            wandb.log({f"reward_{comp}": r.get(comp, 0) for comp in PHASE_WEIGHTS.keys()},
                      step=self.trainer.state.global_step)
            agg = sum(r[comp] * PHASE_WEIGHTS[comp][phase] for comp in PHASE_WEIGHTS.keys())
            aggregated.append(agg)
        if self.trainer.state.global_step % 300 == 0 and len(aggregated) >= 3:
            evolution_func = self.reward_evolution.generate_new_reward(aggregated)
            aggregated = evolution_func(aggregated)
        if len(aggregated) >= 3:
            try:
                rewards_tensor = torch.tensor(aggregated[-3:], dtype=torch.float32)
                calibration_factors = self.reward_calibrator(rewards_tensor.unsqueeze(0))
                calibrated = [agg * cal for agg, cal in zip(aggregated, calibration_factors.tolist())]
                return calibrated
            except Exception as e:
                print(f"[EnhancedDeepCoralTrainer] Reward calibration error: {e}")
        return aggregated

    def execute_training(self):
        wandb.init(project="DEEPCORAL-X")
        trainer = self.configure_training()
        try:
            for step, batch in enumerate(trainer.dataloader):
                current_phase = step // PHASE_TRANSITION_STEPS
                new_rank = INITIAL_LORA_RANK + current_phase * LORA_RANK_INCREMENT
                if new_rank > self.lora_manager.current_rank:
                    self.lora_manager.expand_rank(new_rank)
                    # Warm up the expanded adapter with a halved learning rate for one step.
                    original_lr = trainer.args.learning_rate
                    trainer.args.learning_rate = original_lr * 0.5
                    warmup_metrics = trainer.training_step(batch)
                    wandb.log({"warmup": True, "lr": trainer.args.learning_rate}, step=step)
                    trainer.args.learning_rate = original_lr
                    metrics = warmup_metrics
                else:
                    metrics = trainer.training_step(batch)
                grad_tensor = torch.tensor([v for v in metrics.values() if isinstance(v, (int, float))])
                new_accum = self.grad_accumulator.update(grad_tensor)
                trainer.args.gradient_accumulation_steps = new_accum
                wandb.log({
                    "phase": current_phase,
                    "lora_rank": self.lora_manager.current_rank,
                    **metrics
                }, step=step)
                if step % 100 == 0:
                    self._validation_check()
                    skill_results = self.skill_validator.run_tests()
                    wandb.log({"skill_probes": skill_results}, step=step)
        finally:
            self.lora_manager.model.save_lora("final_adapters")
            wandb.finish()

# -----------------------------------------------------------------------------
# Unit Test Functions
# -----------------------------------------------------------------------------
class DeepCoralTests(unittest.TestCase):
    def test_gsm8k_processor(self):
        processor = GSM8KProcessor()
        sample_solution = r"\boxed{123.45 km}"
        answer = processor._extract_answer(sample_solution)
        self.assertIn("123.45", answer, "Answer extraction failed")
        value, unit = processor._parse_value_unit(answer)
        self.assertIsInstance(value, float, "Value parsing failed")

    def test_dynamic_lora_expansion(self):
        base_model, _ = FastLanguageModel.from_pretrained(
            MODEL_NAME, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True
        )
        lora = DynamicLoRA(base_model)
        original_params = sum(p.numel() for p in lora.model.parameters())
        lora.expand_rank(INITIAL_LORA_RANK + LORA_RANK_INCREMENT)
        new_params = sum(p.numel() for p in lora.model.parameters())
        self.assertGreater(new_params, original_params, "LoRA expansion did not increase parameters")

    def test_reward_orchestrator(self):
        dummy_completions = [
            "<reasoning>Some reasoning</reasoning><answer>150</answer><critique>Looks REAL</critique>"
        ]
        dummy_answers = ["150"]
        dummy_distractors = [["140", "160", "approximately 150"]]
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        dummy_model = FastLanguageModel.from_pretrained(MODEL_NAME, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True)[0]
        orchestrator = RewardOrchestrator(tokenizer, dummy_model)
        rewards = orchestrator.calculate_rewards(
            phase=1,
            prompts=["Test prompt"],
            completions=dummy_completions,
            answers=dummy_answers,
            distractors=dummy_distractors
        )
        self.assertIn("structure", rewards, "Reward keys missing")
        print("RewardOrchestrator test rewards:", rewards)

# -----------------------------------------------------------------------------
# Main Execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Run unit tests.
    unittest.main(exit=False)

    # Execute enhanced training.
    EnhancedDeepCoralTrainer().execute_training()