#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
=====================================================================
DeepSynapse v4.1: Comprehensive Phase-Controlled Training System
=====================================================================
DeepSynapse (Dynamic Error-enhanced, Emergent Phase-controlled Self-Optimizing
Neural Adaptive Processing System Engine) is an advanced reinforcement learning
framework designed for training language models that can autonomously evolve
their capabilities. By combining adaptive parameter expansion, multi-objective
reward optimization, and a phase-controlled curriculum, DeepSynapse achieves
robust, interpretable, and error-resilient structured reasoning.

Core Innovations:
-----------------
1. Dynamic LoRA-Head Scaling with Meta-Contextual Adaptation:
   - Implements phase-progressive adapter rank expansion (64 → 128 → 192, etc.)
     with smoothing to ensure stable transitions.
   - Adapts capacity based on context embeddings derived from current batches.
2. Triple Distractor Anchoring:
   - Generates multi-modal distractors:
     * Numeric: ±20% variance, sign flips, rounding modifications.
     * Semantic: Context-aware synonym rotations using WordNet.
     * Unit: Multi-dimensional conversions using predefined mappings.
3. KL-Temperature Co-Regulation:
   - Uses a cosine-decaying temperature (0.9 → 0.3) along with phase-aligned
     KL divergence penalties.
   - Helps prevent reward hacking and maintains stable output distributions.
4. Reinforced Critique Validation:
   - Extracts dedicated <critique> blocks and evaluates them via a RoBERTa-based
     classifier.
   - Applies phase-dependent penalties if the self-critique does not match
     solution correctness.
5. Phase-Controlled Curriculum & Component Locking:
   - Three-stage training curriculum:
     * Phase 0: Structural Compliance.
     * Phase 1: Reasoning Validation.
     * Phase 2: Precision Refinement.
   - Dynamically activates/deactivates reward components based on phase.
6. Omnidirectional Reward Fusion & Calibration:
   - Computes a 5-dimensional reward vector (Structure, Contrastive, Critique,
     Correctness, KL).
   - Calibrates raw rewards with a neural weight allocator and evolves reward
     functions based on training history.
7. XML Structural Guardian:
   - Enforces strict XML formatting (<reasoning>, <answer>, and <critique> tags).
   - Applies dynamic length penalties to discourage verbosity.
8. Integrated Performance Monitoring:
   - Fully integrated with Weights & Biases (W&B) for real-time telemetry and
     debugging.

Advanced Roadmap:
-----------------
Future evolutions include memory-augmented networks for long-term retention,
dynamic gradient accumulation based on advanced metrics, and a fully
self-evolving reward system.

Expected Emergent Capabilities:
-------------------------------
- Enhanced counterfactual reasoning, robust self-debugging, and superior
  zero-shot problem-solving.
"""
import re
import torch
import wandb
import random
import numpy as np
import unittest

from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from unsloth import FastLanguageModel, is_bfloat16_supported
from vllm import SamplingParams

# Additional imports for distractor generation
import nltk

try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')

from nltk.corpus import wordnet
from itertools import chain
# -----------------------------------------------------------------------------
# Configuration Constants
# -----------------------------------------------------------------------------
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
REF_MODEL_NAME = "Qwen/Qwen2.5-3B"  # Reference model for KL divergence
MAX_SEQ_LENGTH = 2048
INITIAL_LORA_RANK = 64
LORA_RANK_INCREMENT = 64  # Progression: 64 -> 128 -> 192, etc.
PHASE_TRANSITION_STEPS = 300

# Phase-specific reward weights for 3 phases (0, 1, 2)
PHASE_WEIGHTS = {
    'structure': [0.6, 0.3, 0.1],
    'contrastive': [0.0, 0.4, 0.2],
    'critique': [0.1, 0.2, 0.3],
    'correctness': [0.1, 0.3, 0.6],
    'kl': [0.0, 0.1, 0.2],
}
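
# Illustrative sketch (not referenced elsewhere in this file): how a global step
# maps to a curriculum phase and how the per-phase weights above are looked up.
# The clamp to phase 2 mirrors the `min(step // PHASE_TRANSITION_STEPS, 2)` logic
# used by the trainers below.
def phase_for_step(step: int) -> int:
    """Return the curriculum phase (0, 1, or 2) for a given global step."""
    return min(step // PHASE_TRANSITION_STEPS, 2)


def weights_for_step(step: int) -> dict:
    """Return the reward-component weights active at a given global step."""
    phase = phase_for_step(step)
    return {component: weights[phase] for component, weights in PHASE_WEIGHTS.items()}

# Example: weights_for_step(450) -> phase 1 weights, e.g. {'structure': 0.3, 'contrastive': 0.4, ...}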
SYSTEM_PROMPT = """Respond using structured reasoning followed by a concise answer:
<reasoning>
Step-by-step logical explanation...
</reasoning>
<answer>
Final numerical answer only
</answer>
<critique>
Your self-critique here.
</critique>"""
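
# Illustrative sketch of the cosine-decaying temperature schedule (0.9 -> 0.3)
# described in the module docstring. The GRPOConfig instances below use a linear
# ramp (0.9 - 0.6 * min(1, step / 900)); this helper shows the cosine variant and
# is not wired into the trainers. `cosine_temperature` is a hypothetical helper
# name introduced here for illustration only.
def cosine_temperature(step: int, total_steps: int = 900,
                       t_start: float = 0.9, t_end: float = 0.3) -> float:
    """Cosine-anneal the sampling temperature from t_start down to t_end."""
    progress = min(1.0, step / total_steps)
    return t_end + 0.5 * (t_start - t_end) * (1 + float(np.cos(np.pi * progress)))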
# -----------------------------------------------------------------------------
# Advanced Components
# -----------------------------------------------------------------------------

# 1. Hybrid Modular Memory: Memory-augmented neural network (MANN)
class NeuralMemoryBank:
    def __init__(self, model_dim=1024):
        self.memory = []  # Stores (key, value) pairs as tensors
        self.attention = torch.nn.MultiheadAttention(embed_dim=model_dim, num_heads=4)

    def retrieve(self, query, k=3):
        # query: tensor of shape (embed_dim,)
        query = query.unsqueeze(0)  # (1, embed_dim)
        if not self.memory:
            return query.squeeze(0)
        keys = torch.stack([m[0] for m in self.memory])    # (N, embed_dim)
        values = torch.stack([m[1] for m in self.memory])  # (N, embed_dim)
        # Reshape keys and values to (N, 1, embed_dim) = (seq_len, batch, embed_dim)
        keys = keys.unsqueeze(1)
        values = values.unsqueeze(1)
        query = query.unsqueeze(0)  # (1, 1, embed_dim)
        # Attention pools the stored memories into a single context vector, so the
        # slice below returns at most one row regardless of k.
        attn_output, _ = self.attention(query, keys, values)
        return attn_output.squeeze(0)[:k]

    def store(self, key, value):
        self.memory.append((key.detach(), value.detach()))
        if len(self.memory) > 1000:  # Use FIFO eviction when memory is full
            self.memory.pop(0)
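
# Minimal usage sketch for NeuralMemoryBank (illustrative only, not called during
# training). The random tensors stand in for hidden-state embeddings; `model_dim`
# must match the embeddings you store.
def _demo_memory_bank():
    bank = NeuralMemoryBank(model_dim=1024)
    for _ in range(5):
        emb = torch.randn(1024)
        bank.store(key=emb, value=emb)
    context = bank.retrieve(torch.randn(1024))
    return context.shape  # torch.Size([1, 1024])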
# 2. Meta-Contextual Adaptation: Lightweight hypernetwork for LoRA rank scaling
class HyperNetwork(torch.nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256):
        super(HyperNetwork, self).__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)  # Predicts a scaling factor
        )

    def forward(self, context_embedding):
        # context_embedding: tensor of shape (batch, input_dim)
        mean_embedding = context_embedding.mean(dim=0, keepdim=True)
        scaling = self.net(mean_embedding)  # shape (1, 1)
        return scaling.squeeze(0)  # shape (1,); callers read the value via .item()
# 3. Dynamic Weight Adjustment: Neural network-based weight allocator for rewards
class NeuralWeightAllocator(torch.nn.Module):
    def __init__(self, num_rewards):
        super(NeuralWeightAllocator, self).__init__()
        self.net = torch.nn.Linear(num_rewards * 3, num_rewards)

    def forward(self, reward_history):
        # reward_history: tensor of shape (3, num_rewards)
        hist_flat = reward_history.flatten().unsqueeze(0)  # shape (1, num_rewards * 3)
        weights = self.net(hist_flat)
        return torch.softmax(weights, dim=1).squeeze(0)  # shape (num_rewards,)
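
# Minimal usage sketch for NeuralWeightAllocator (illustrative only). The allocator
# expects a history covering the last 3 steps of per-component rewards, shaped
# (3, num_rewards); it returns one softmax-normalized weight per reward component.
def _demo_weight_allocator():
    allocator = NeuralWeightAllocator(num_rewards=5)
    history = torch.randn(3, 5)   # 3 recent steps x 5 reward components
    weights = allocator(history)  # shape (5,), sums to 1.0
    return weights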
# 4. Auto-Discovered Reward Components: Uses LLM-generated reward templates to evolve reward functions
class RewardEvolution:
    def __init__(self, generator_model):
        self.generator = generator_model  # Text-generation pipeline

    def generate_new_reward(self, training_history):
        history_str = ", ".join(str(r) for r in training_history)
        prompt = (
            f"Analyze these training reward values: {history_str}. "
            "Propose a multiplicative factor to improve reward calibration."
        )
        # max_new_tokens bounds only the generated continuation, so long reward
        # histories in the prompt are not cut off by the length limit.
        output = self.generator(prompt, max_new_tokens=50, truncation=True)[0]['generated_text']
        factor = self._parse_factor(output)
        print(f"[RewardEvolution] New calibration factor: {factor}")
        return lambda rewards: [r * factor for r in rewards]

    def _parse_factor(self, text):
        matches = re.findall(r"[\d\.]+", text)
        if matches:
            try:
                return float(matches[0])
            except ValueError:
                return 1.0
        return 1.0
# 5. Dynamic Gradient Accumulation: Adaptive accumulator using an EWMA of gradient variance
class AdaptiveAccumulator:
    def __init__(self, init_steps=4, alpha=0.3):
        self.accum_steps = init_steps
        self.alpha = alpha
        self.ewma = None

    def update(self, gradients):
        # var() of a single element is NaN (unbiased estimator), so require > 1.
        current_var = gradients.var().item() if gradients.numel() > 1 else 0.0
        if self.ewma is None:
            self.ewma = current_var
        else:
            self.ewma = self.alpha * current_var + (1 - self.alpha) * self.ewma
        # Adjust accumulation steps: high variance shrinks the step count, while
        # lower variance lets us increase steps for smoother updates.
        if self.ewma > 0.1:
            self.accum_steps = max(2, self.accum_steps - 1)
        else:
            self.accum_steps = min(8, self.accum_steps + 1)
        print(f"[AdaptiveAccumulator] EWMA: {self.ewma:.4f}, Accumulation Steps: {self.accum_steps}")
        return self.accum_steps
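
# Worked example of the EWMA update above (illustrative): with alpha = 0.3, a
# previous EWMA of 0.05 and a new batch gradient variance of 0.2,
#   ewma = 0.3 * 0.2 + 0.7 * 0.05 = 0.095,
# which stays below the 0.1 threshold, so accum_steps is increased (capped at 8).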
# 6. Selective Activation Recompilation: Activation caching for efficiency.
class EfficientTrainer(GRPOTrainer):
    def __init__(self, *args, **kwargs):
        super(EfficientTrainer, self).__init__(*args, **kwargs)
        self.activation_cache = {}

    def training_step(self, batch):
        # Note: this signature matches the custom step-wise loop in
        # execute_training() below, not the stock transformers Trainer API.
        with torch.no_grad():
            base_out = self.model(**batch, output_hidden_states=True)
            if hasattr(base_out, "hidden_states"):
                self.activation_cache['hidden'] = base_out.hidden_states
        return super(EfficientTrainer, self).training_step(batch)
# 7. Curriculum-Driven Multi-Objective Learning: Phase-adaptive curriculum sampler.
class CurriculumSampler:
    def __init__(self, dataset):
        self.dataset = dataset
        self.difficulty_scores = self._calculate_difficulty()

    def _calculate_difficulty(self):
        scores = []
        for ex in self.dataset:
            score = len(ex["prompt"]) / 100.0
            scores.append(score)
        return scores

    def sample_batch(self, phase):
        dataset_size = len(self.dataset)
        sorted_indices = np.argsort(self.difficulty_scores)
        if phase == 0:
            idxs = sorted_indices[: dataset_size // 3]
        elif phase == 1:
            idxs = sorted_indices[dataset_size // 3: 2 * dataset_size // 3]
        else:
            idxs = sorted_indices[2 * dataset_size // 3:]
        return self.dataset.select(list(idxs))
# 8. Emergent Skill Probes: Automated capability tests during validation.
class EmergentSkillValidator:
    TEST_PROMPTS = {
        "counterfactual": "If a problem stated A instead of B, how would your solution change?",
        "generalization": "Solve this unseen problem: What is the square root of 256?",
        "self_critique": "Identify potential flaws in the following solution: <reasoning>...</reasoning><answer>...</answer>"
    }

    def __init__(self, model):
        self.model = model

    def run_tests(self):
        # Assumes the model exposes a vLLM-style generate(prompt, SamplingParams) interface.
        results = {}
        for skill, template in self.TEST_PROMPTS.items():
            response = self.model.generate(template, SamplingParams(temperature=0.7, max_tokens=100))
            results[skill] = self._evaluate_response(skill, response[0].outputs[0].text)
        return results

    def _evaluate_response(self, skill, response):
        # Crude heuristic: any non-trivial response counts as a pass.
        return len(response) > 20
# 9. Reward Orchestration: base RewardOrchestrator (extended below by
#    EnhancedRewardOrchestrator).
class RewardOrchestrator:
    def __init__(self, tokenizer, main_model):
        self.tokenizer = tokenizer
        self.main_model = main_model
        self.ref_tokenizer = AutoTokenizer.from_pretrained(REF_MODEL_NAME)
        self.ref_model = AutoModelForCausalLM.from_pretrained(REF_MODEL_NAME)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ref_model.to(device)
        self.validator = pipeline(
            "text-classification",
            model="roberta-base-openai-detector",
            device=0 if torch.cuda.is_available() else -1
        )

    def calculate_rewards(self, phase, prompts, completions, answers, distractors):
        rewards_dict = {
            'structure': self._structural_reward(completions),
            'contrastive': self._contrastive_reward(completions, answers, distractors),
            'critique': self._critique_reward(completions, answers, phase),
            'correctness': self._correctness_reward(completions, answers),
            'kl': self._kl_reward(prompts)
        }
        return rewards_dict
    def _structural_reward(self, completions):
        rewards = []
        for comp in completions:
            has_reasoning = ("<reasoning>" in comp and "</reasoning>" in comp
                             and comp.find("<reasoning>") < comp.find("</reasoning>"))
            has_answer = ("<answer>" in comp and "</answer>" in comp
                          and comp.find("<answer>") < comp.find("</answer>"))
            has_critique = ("<critique>" in comp and "</critique>" in comp
                            and comp.find("<critique>") < comp.find("</critique>"))
            valid = has_reasoning and has_answer
            score = 1.0 if valid else -1.0
            if has_critique:
                score += 0.2 if comp.find("<answer>") < comp.find("<critique>") else -0.1
            length_penalty = max(0, (len(comp) - 200) // 50 * 0.1)
            rewards.append(score - length_penalty)
        return rewards

    def _contrastive_reward(self, completions, answers, distractors):
        rewards = []
        for comp, ans, dists in zip(completions, answers, distractors):
            comp_val = self._parse_numeric(comp)
            ans_val = self._parse_numeric(ans)
            if np.isnan(comp_val) or np.isnan(ans_val):
                rewards.append(-1.0)
                continue
            dist_diffs = [abs(comp_val - self._parse_numeric(d)) for d in dists if self._is_number(d)]
            min_dist = min(dist_diffs) if dist_diffs else 0.0
            diff = abs(comp_val - ans_val)
            reward = 2.0 if diff < 0.01 else 1.0 / (1 + diff) - 0.3 * min_dist
            rewards.append(max(min(reward, 2.0), -1.0))
        return rewards
    def _critique_reward(self, completions, answers, phase):
        rewards = []
        for comp, ans in zip(completions, answers):
            critique = self._extract_critique(comp)
            if not critique:
                rewards.append(-1.5 * [0.8, 1.0, 1.2][phase])
                continue
            # Label casing varies across detector checkpoints ("Real" vs "REAL"),
            # so compare case-insensitively.
            valid = self.validator(critique[:512])[0]["label"].upper() == "REAL"
            try:
                comp_val = self._parse_numeric(comp)
                ans_val = self._parse_numeric(ans)
                correct = abs(comp_val - ans_val) < 0.01
            except (TypeError, ValueError):
                correct = False
            base = 1.0 if valid else -1.5
            phase_weight = [0.8, 1.0, 1.2][phase]
            rewards.append(base * phase_weight * (1.2 if correct else 0.8))
        return rewards

    def _correctness_reward(self, completions, answers):
        rewards = []
        for c, a in zip(completions, answers):
            try:
                if abs(self._parse_numeric(c) - self._parse_numeric(a)) < 0.01:
                    rewards.append(2.0)
                else:
                    rewards.append(-1.0)
            except (TypeError, ValueError):
                rewards.append(-1.0)
        return rewards
    def _kl_reward(self, prompts):
        inputs = self.ref_tokenizer(prompts, return_tensors="pt", padding=True,
                                    truncation=True, max_length=MAX_SEQ_LENGTH)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            base_logits = self.ref_model(**inputs).logits
            current_logits = self.main_model(**inputs).logits
        kl_div = torch.nn.functional.kl_div(
            torch.log_softmax(current_logits, dim=-1),
            torch.softmax(base_logits, dim=-1),
            reduction='batchmean'
        )
        return [-kl_div.item()] * len(prompts)

    def _extract_critique(self, text):
        match = re.search(r"<critique>(.*?)</critique>", text, re.DOTALL)
        return match.group(1).strip() if match else ""

    def _parse_numeric(self, text):
        try:
            m = re.search(r"[-+]?\d*\.?\d+", text)
            return float(m.group()) if m else float('nan')
        except (TypeError, ValueError):
            return float('nan')

    def _is_number(self, s):
        try:
            float(s)
            return True
        except (TypeError, ValueError):
            return False
# EnhancedRewardOrchestrator: adds memory retrieval and weight allocation.
class EnhancedRewardOrchestrator(RewardOrchestrator):
    def __init__(self, tokenizer, main_model):
        super().__init__(tokenizer, main_model)
        self.memory = NeuralMemoryBank()
        self.weight_allocator = NeuralWeightAllocator(num_rewards=5)

    def calculate_rewards(self, phase, prompts, completions, answers, distractors):
        base_rewards = super().calculate_rewards(phase, prompts, completions, answers, distractors)
        # Optionally integrate a memory bonus (a small constant bonus is used here for demonstration)
        memory_bonus = 0.1
        rewards_list = []
        reward_keys = ['structure', 'contrastive', 'critique', 'correctness', 'kl']
        for i in range(len(prompts)):
            rewards_dict = {k: base_rewards[k][i] + memory_bonus for k in reward_keys}
            rewards_list.append(rewards_dict)
        return rewards_list
# 10. Dynamic LoRA Adapter (base version)
class DynamicLoRA:
    def __init__(self, base_model):
        self.model = base_model
        self.current_rank = INITIAL_LORA_RANK
        self._initialize_lora()

    def _initialize_lora(self):
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=INITIAL_LORA_RANK,
            lora_alpha=INITIAL_LORA_RANK * 2,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            use_gradient_checkpointing=True,
        )

    def expand_rank(self, new_rank):
        if new_rank <= self.current_rank:
            return
        try:
            # Relies on adapter export/inject hooks; if they are unavailable in the
            # installed unsloth/PEFT version, the exception is logged and the
            # current rank is kept.
            adapter_state = self.model.get_adapter_state()
            new_config = {**self.model.peft_config, "r": new_rank, "lora_alpha": new_rank * 2}
            self.model = FastLanguageModel.inject_adapter(self.model, new_config)
            self.model.load_adapter(adapter_state, strict=False)
            self.current_rank = new_rank
            print(f"[DynamicLoRA] LoRA rank expanded to {new_rank}")
        except Exception as e:
            print(f"[DynamicLoRA] Error during LoRA expansion: {e}")
# 11. DynamicLoRAWithContext: Uses a hypernetwork for contextual rank adjustment.
class DynamicLoRAWithContext(DynamicLoRA):
    def __init__(self, base_model):
        super().__init__(base_model)
        self.hypernet = HyperNetwork()

    def contextual_rank_adjustment(self, context_embeddings=None):
        if context_embeddings is None:
            context_embeddings = torch.randn(1, 512).to(next(self.model.parameters()).device)
        # Keep the hypernetwork on the same device as the context embeddings.
        self.hypernet.to(context_embeddings.device)
        scaling_factor = self.hypernet(context_embeddings)
        factor = 1 + 0.1 * scaling_factor.item()
        new_rank = int(self.current_rank * factor)
        # Cap growth at one increment per adjustment.
        new_rank = min(new_rank, self.current_rank + LORA_RANK_INCREMENT)
        if new_rank > self.current_rank:
            print(f"[DynamicLoRAWithContext] Adjusting rank from {self.current_rank} to {new_rank} based on context")
            self.expand_rank(new_rank)
# 12. GSM8KProcessor: Processes the GSM8K dataset.
class GSM8KProcessor:
    def __init__(self):
        self.unit_conversions = {
            'km': (0.621371, 'mi'),
            'hours': (60, 'minutes'),
            '$': (100, 'cents'),
        }

    def process_dataset(self):
        dataset = load_dataset("gsm8k", "main")["train"]
        return dataset.map(self._process_example, remove_columns=dataset.column_names)

    def _process_example(self, example):
        answer = self._extract_answer(example["answer"])
        return {
            "prompt": f"Solve: {example['question']}\nUse XML structure:",
            "answer": answer,
            "distractors": self._generate_distractors(answer),
        }

    def _extract_answer(self, solution):
        # GSM8K solutions end with "#### <final answer>"; fall back to \boxed{...}
        # or a dollar amount if that marker is missing.
        match = re.search(r"####\s*([^\n]+)", solution)
        if not match:
            match = re.search(r"\\boxed{([^}]+)}", solution)
        if not match:
            match = re.search(r"\$\s*([+-]?\d+\.?\d*)", solution)
        if not match:
            print(f"[GSM8KProcessor] Warning: No valid answer found in solution: {solution}")
        extracted = match.group(1) if match else "0"
        return self._normalize_value(extracted)

    def _normalize_value(self, value_str):
        return value_str.replace(",", "").strip()
    def _generate_distractors(self, answer):
        value, unit = self._parse_value_unit(answer)
        return [
            self._numeric_distractor(value, unit),
            self._unit_distractor(value, unit),
            self._semantic_distractor(value, unit)
        ]

    def _parse_value_unit(self, text):
        match = re.match(r"([+-]?\d+\.?\d*)(.*)", text.strip())
        if match:
            return float(match.group(1)), match.group(2).strip()
        return 0.0, ""

    def _numeric_distractor(self, value, unit):
        variation = value * random.choice([1.2, 0.8, -1])
        return f"{variation:.2f}{unit}"

    def _unit_distractor(self, value, unit):
        for pattern, (factor, new_unit) in self.unit_conversions.items():
            if pattern in unit:
                return f"{value * factor:.2f} {new_unit}"
        return f"{value}{random.choice([' m', ' kg', ' s'])}"

    def _semantic_distractor(self, value, unit):
        if unit:
            synsets = wordnet.synsets(unit)
            lemmas = set(chain.from_iterable(syn.lemma_names() for syn in synsets)) if synsets else set()
            if lemmas:
                synonym = random.choice(list(lemmas))
                return f"approximately {value:.1f} {synonym}"
            variations = [
                f"approximately {value:.1f} {unit}",
                f"around {value:.1f} {unit}",
                f"roughly {value:.1f} {unit}",
                f"nearly {value:.1f} {unit}"
            ]
            return random.choice(variations)
        return f"~{value:.0f}"
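
# Illustrative sketch (not used in training): what a processed GSM8K record might
# look like. The example record is hypothetical and the exact distractor strings
# vary because generation is randomized.
def _demo_gsm8k_distractors():
    proc = GSM8KProcessor()
    example = {"question": "A train travels 300 km in 3 hours. What is its speed?",
               "answer": "300 / 3 = 100 km/h\n#### 100"}
    record = proc._process_example(example)
    # record["answer"] == "100"
    # record["distractors"] -> e.g. ["120.00", "100.0 m", "~100"]
    return record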
# 13. DeepCoralTrainer: Base trainer for DeepSynapse training.
class DeepCoralTrainer:
    def __init__(self):
        self.base_model, self.tokenizer = FastLanguageModel.from_pretrained(
            MODEL_NAME,
            max_seq_length=MAX_SEQ_LENGTH,
            load_in_4bit=True
        )
        self.lora_manager = DynamicLoRA(self.base_model)
        self.dataset = GSM8KProcessor().process_dataset()
        self.reward_system = RewardOrchestrator(self.tokenizer, self.lora_manager.model)
        self.trainer = None

    def configure_training(self):
        # Note: temperature_scheduler, kl_weight_scheduler, reward_func and
        # reward_aggregator are assumed extensions of GRPOConfig/GRPOTrainer used
        # throughout this script; they are not part of the stock TRL API.
        args = GRPOConfig(
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            max_steps=900,
            learning_rate=2e-5,
            temperature_scheduler=lambda s: 0.9 - 0.6 * min(1, s / 900),
            kl_weight_scheduler=lambda s: PHASE_WEIGHTS['kl'][min(s // PHASE_TRANSITION_STEPS, 2)],
            report_to="wandb"
        )
        self.trainer = GRPOTrainer(
            model=self.lora_manager.model,
            args=args,
            train_dataset=self.dataset,
            reward_func=self._phase_aware_reward,
            reward_aggregator=self._aggregate_rewards,
        )
        return self.trainer
    def _phase_aware_reward(self, prompts, completions, answers, distractors):
        phase = min(self.trainer.state.global_step // PHASE_TRANSITION_STEPS, 2)
        return self.reward_system.calculate_rewards(phase, prompts, completions, answers, distractors)

    def _aggregate_rewards(self, phase, rewards):
        aggregated = []
        for r in rewards:
            agg = sum(r[comp] * PHASE_WEIGHTS[comp][phase] for comp in PHASE_WEIGHTS.keys())
            aggregated.append(agg)
        return aggregated
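
    # Worked example of the aggregation above (illustrative): in phase 1, a sample
    # with per-component rewards {structure: 1.0, contrastive: 0.5, critique: 0.8,
    # correctness: 2.0, kl: -0.1} is scored as
    #   1.0*0.3 + 0.5*0.4 + 0.8*0.2 + 2.0*0.3 + (-0.1)*0.1 = 1.25.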
    def execute_training(self):
        wandb.init(project="DEEPCORAL-X")
        trainer = self.configure_training()
        try:
            for step, batch in enumerate(trainer.get_train_dataloader()):
                current_phase = step // PHASE_TRANSITION_STEPS
                new_rank = INITIAL_LORA_RANK + current_phase * LORA_RANK_INCREMENT
                if new_rank > self.lora_manager.current_rank:
                    self.lora_manager.expand_rank(new_rank)
                    # Warm up the freshly expanded adapters for one step at half the learning rate.
                    original_lr = trainer.args.learning_rate
                    trainer.args.learning_rate = original_lr * 0.5
                    warmup_metrics = trainer.training_step(batch)
                    wandb.log({"warmup": True, "lr": trainer.args.learning_rate}, step=step)
                    trainer.args.learning_rate = original_lr
                    metrics = warmup_metrics
                else:
                    metrics = trainer.training_step(batch)
                wandb.log({
                    "phase": current_phase,
                    "lora_rank": self.lora_manager.current_rank,
                    **metrics
                }, step=step)
                if step % 100 == 0:
                    self._validation_check()
        finally:
            self.lora_manager.model.save_lora("final_adapters")
            wandb.finish()
    def _validation_check(self):
        sample_prompts = [
            "Solve: If a train travels 300 km in 3 hours, what is its speed? Use XML structure:",
            "Solve: A store sells apples for $0.50 each. How much do 12 apples cost? Use XML structure:"
        ]
        sampling_params = SamplingParams(temperature=0.7, max_tokens=200)
        for prompt in sample_prompts:
            completion = self.lora_manager.model.generate(prompt, sampling_params)
            print(f"[Validation] Prompt: {prompt}")
            print(f"[Validation] Completion: {completion}")
# 14. EnhancedDeepCoralTrainer: Incorporates the advanced modules.
class EnhancedDeepCoralTrainer(DeepCoralTrainer):
    def __init__(self):
        super().__init__()
        self.lora_manager = DynamicLoRAWithContext(self.base_model)
        self.reward_system = EnhancedRewardOrchestrator(self.tokenizer, self.lora_manager.model)
        self.curriculum = CurriculumSampler(self.dataset)
        self.skill_validator = EmergentSkillValidator(self.lora_manager.model)
        self.grad_accumulator = AdaptiveAccumulator(init_steps=4, alpha=0.3)
        self.reward_evolution = RewardEvolution(
            generator_model=pipeline("text-generation", model=MODEL_NAME, tokenizer=MODEL_NAME)
        )
        self.reward_calibrator = NeuralWeightAllocator(num_rewards=5)

    def configure_training(self):
        args = GRPOConfig(
            per_device_train_batch_size=4,
            gradient_accumulation_steps=self.grad_accumulator.accum_steps,
            max_steps=900,
            learning_rate=2e-5,
            temperature_scheduler=lambda s: 0.9 - 0.6 * min(1, s / 900),
            kl_weight_scheduler=lambda s: PHASE_WEIGHTS['kl'][min(s // PHASE_TRANSITION_STEPS, 2)],
            report_to="wandb"
        )
        self.trainer = EfficientTrainer(
            model=self.lora_manager.model,
            args=args,
            train_dataset=self.dataset,
            reward_func=self._phase_aware_reward,
            reward_aggregator=self._aggregate_rewards,
        )
        return self.trainer
    def _phase_aware_reward(self, prompts, completions, answers, distractors):
        phase = min(self.trainer.state.global_step // PHASE_TRANSITION_STEPS, 2)
        try:
            # Simulated context embeddings; replace with a real encoder if available.
            context_embeddings = torch.randn(1, 512).to(next(self.lora_manager.model.parameters()).device)
        except Exception as e:
            print(f"[EnhancedDeepCoralTrainer] Error obtaining context embeddings: {e}")
            context_embeddings = None
        self.lora_manager.contextual_rank_adjustment(context_embeddings)
        return self.reward_system.calculate_rewards(phase, prompts, completions, answers, distractors)

    def _aggregate_rewards(self, phase, rewards):
        aggregated = []
        for r in rewards:
            wandb.log({f"reward_{comp}": r.get(comp, 0) for comp in PHASE_WEIGHTS.keys()},
                      step=self.trainer.state.global_step)
            agg = sum(r[comp] * PHASE_WEIGHTS[comp][phase] for comp in PHASE_WEIGHTS.keys())
            aggregated.append(agg)
        if self.trainer.state.global_step % 300 == 0 and len(aggregated) >= 3:
            evolution_func = self.reward_evolution.generate_new_reward(aggregated)
            aggregated = evolution_func(aggregated)
        if len(aggregated) >= 3:
            try:
                # The calibrator expects a (3, num_rewards) history; the last three
                # aggregated scalars are tiled across the 5 reward slots here as a
                # simple placeholder history.
                rewards_tensor = torch.tensor(aggregated[-3:], dtype=torch.float32)
                history = rewards_tensor.unsqueeze(1).repeat(1, 5)  # shape (3, 5)
                calibration_factors = self.reward_calibrator(history)
                calibrated = [agg * cal for agg, cal in zip(aggregated, calibration_factors.tolist())]
                return calibrated
            except Exception as e:
                print(f"[EnhancedDeepCoralTrainer] Reward calibration error: {e}")
        return aggregated
    def execute_training(self):
        wandb.init(project="DEEPCORAL-X")
        trainer = self.configure_training()
        try:
            for step, batch in enumerate(trainer.get_train_dataloader()):
                current_phase = step // PHASE_TRANSITION_STEPS
                new_rank = INITIAL_LORA_RANK + current_phase * LORA_RANK_INCREMENT
                if new_rank > self.lora_manager.current_rank:
                    self.lora_manager.expand_rank(new_rank)
                    original_lr = trainer.args.learning_rate
                    trainer.args.learning_rate = original_lr * 0.5
                    warmup_metrics = trainer.training_step(batch)
                    wandb.log({"warmup": True, "lr": trainer.args.learning_rate}, step=step)
                    trainer.args.learning_rate = original_lr
                    metrics = warmup_metrics
                else:
                    metrics = trainer.training_step(batch)
                # Adapt gradient accumulation from the scalar training metrics
                # (used here as a proxy for gradient statistics).
                grad_tensor = torch.tensor([v for v in metrics.values() if isinstance(v, (int, float))])
                new_accum = self.grad_accumulator.update(grad_tensor)
                trainer.args.gradient_accumulation_steps = new_accum
                wandb.log({
                    "phase": current_phase,
                    "lora_rank": self.lora_manager.current_rank,
                    **metrics
                }, step=step)
                if step % 100 == 0:
                    self._validation_check()
                    skill_results = self.skill_validator.run_tests()
                    wandb.log({"skill_probes": skill_results}, step=step)
        finally:
            self.lora_manager.model.save_lora("final_adapters")
            wandb.finish()
# -----------------------------------------------------------------------------
# Unit Test Functions
# -----------------------------------------------------------------------------
class DeepCoralTests(unittest.TestCase):
    def test_gsm8k_processor(self):
        processor = GSM8KProcessor()
        sample_solution = r"\boxed{123.45 km}"
        answer = processor._extract_answer(sample_solution)
        self.assertIn("123.45", answer, "Answer extraction failed")
        value, unit = processor._parse_value_unit(answer)
        self.assertIsInstance(value, float, "Value parsing failed")

    def test_dynamic_lora_expansion(self):
        base_model, _ = FastLanguageModel.from_pretrained(
            MODEL_NAME, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True
        )
        lora = DynamicLoRA(base_model)
        original_params = sum(p.numel() for p in lora.model.parameters())
        lora.expand_rank(INITIAL_LORA_RANK + LORA_RANK_INCREMENT)
        new_params = sum(p.numel() for p in lora.model.parameters())
        self.assertGreater(new_params, original_params, "LoRA expansion did not increase parameters")

    def test_reward_orchestrator(self):
        dummy_completions = [
            "<reasoning>Some reasoning</reasoning><answer>150</answer><critique>Looks REAL</critique>"
        ]
        dummy_answers = ["150"]
        dummy_distractors = [["140", "160", "approximately 150"]]
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        dummy_model = FastLanguageModel.from_pretrained(
            MODEL_NAME, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True
        )[0]
        orchestrator = RewardOrchestrator(tokenizer, dummy_model)
        rewards = orchestrator.calculate_rewards(
            phase=1,
            prompts=["Test prompt"],
            completions=dummy_completions,
            answers=dummy_answers,
            distractors=dummy_distractors
        )
        self.assertIn("structure", rewards, "Reward keys missing")
        print("RewardOrchestrator test rewards:", rewards)
# -----------------------------------------------------------------------------
# Main Execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Run unit tests.
    unittest.main(exit=False)
    # Execute enhanced training.
    EnhancedDeepCoralTrainer().execute_training()