#!/usr/bin/env python
"""
A simplified GRPO training script for Qwen2-VL.

This script:
  - Loads a dataset (assumed to have a "problem" and a "solution" field).
  - Formats each example as a conversation prompt.
  - Loads the Qwen2-VL model and its tokenizer.
  - Creates a "reference" model (a copy of the loaded model) used for KL regularization.
  - Runs a custom GRPO training loop:
      * For each batch, generates multiple completions.
      * Computes per-token log probabilities from both the current and reference model.
      * Computes a (very simple) reward per example.
      * Computes a GRPO loss as a (negative) advantage-weighted log probability plus a KL penalty.
  - Saves the final model.

IMPORTANT:
  - This script is intentionally simplified.
  - Many details from the original repo (e.g. image support, multi-reward functions, accelerator handling) are removed.
  - If you need more "real" math verification or additional reward functions, please update the reward functions accordingly.

Adjust the placeholder configuration values near the top as needed.
"""

import os

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from datasets import load_dataset
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoTokenizer,
)

# ================================
# CONFIGURATION (EDIT THESE VALUES)
# ================================

# Dataset & model
DATASET_NAME = "your_dataset_name"      # e.g., "HuggingFaceH4/YourDataset"
DATASET_CONFIG = "default"              # if needed; otherwise, leave as "default"
TRAIN_SPLIT = "train"                   # training split name
# (Assume the dataset examples have at least the keys "problem" and "solution")
MODEL_NAME = "your-model-org/Qwen2-VL"  # placeholder model name (must be a Qwen2-VL variant)
OUTPUT_DIR = "./grpo_trained_model"     # where to save the model

# Training hyperparameters
NUM_EPOCHS = 3
BATCH_SIZE = 2          # adjust for your GPU memory
LEARNING_RATE = 5e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# GRPO-specific hyperparameters
MAX_PROMPT_LENGTH = 512       # maximum tokens for the prompt
MAX_COMPLETION_LENGTH = 128   # maximum new tokens to generate per prompt
NUM_GENERATIONS = 3           # number of generations per prompt for GRPO
BETA = 0.1                    # KL regularization coefficient
GEN_TEMPERATURE = 1.0         # temperature for generation

# A simple system prompt (used to "prime" the conversation)
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The assistant is expected to think through the problem "
    "and then answer. The reasoning is enclosed within <think>...</think> and the final answer within <answer>...</answer>."
)

# ================================
# UTILITY FUNCTIONS & CLASSES
# ================================

def accuracy_reward_simple(completion: str, solution: str) -> float:
    """
    A very simple reward function: returns 1.0 if the (stripped) solution text appears
    in the generated completion; else returns 0.0.
    """
    return 1.0 if solution.strip() in completion.strip() else 0.0
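

# Optional illustration, not wired into the trainer below: a format-based reward that checks
# for the <think>...</think> / <answer>...</answer> structure requested in SYSTEM_PROMPT.
# The function name and regex are this sketch's own choices; if desired, combine it with
# accuracy_reward_simple inside SimpleGRPOTrainer.compute_loss.
import re  # kept next to the sketch so the snippet stays self-contained


def format_reward_simple(completion: str) -> float:
    """Returns 1.0 if the completion contains a <think> block followed by an <answer> block."""
    pattern = r"<think>.*?</think>.*?<answer>.*?</answer>"
    return 1.0 if re.search(pattern, completion, flags=re.DOTALL) else 0.0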


def make_conversation(example):
    """
    Formats a dataset example into a conversation.
    Assumes each example has a "problem" and a "solution".
    """
    # Here we simply prepend a system prompt to the problem.
    prompt = f"{SYSTEM_PROMPT}\nUser: {example['problem']}\nAssistant:"
    # The solution (used later for reward computation) remains untouched.
    return {"prompt": prompt, "solution": example["solution"]}
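
# For a hypothetical example {"problem": "What is 2 + 2?", "solution": "4"}, the resulting
# prompt is simply:
#
#     "<SYSTEM_PROMPT text>\nUser: What is 2 + 2?\nAssistant:"
#
# i.e. plain string concatenation rather than a chat template, in keeping with this
# script's simplifications.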


class SimpleGRPOTrainer:
    """
    A simplified GRPO trainer that:
      - For each batch, generates multiple completions per prompt.
      - Computes log probabilities (via a forward pass) on both the current model and a reference model.
      - Computes a simple reward (via an accuracy check).
      - Computes a GRPO loss (negative advantage * logprob plus a KL penalty).
    """

    def __init__(self, model, ref_model, tokenizer, train_dataset, optimizer, device,
                 num_generations, max_prompt_length, max_completion_length, beta, gen_temperature):
        self.model = model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.optimizer = optimizer
        self.device = device
        self.num_generations = num_generations
        self.max_prompt_length = max_prompt_length
        self.max_completion_length = max_completion_length
        self.beta = beta
        self.gen_temperature = gen_temperature

    def compute_loss(self, prompts, solutions):
        """
        For a list of prompt strings and corresponding solution strings:
          1. Tokenize the prompts.
          2. For each prompt, generate NUM_GENERATIONS completions.
          3. Compute per-sequence log probabilities for each completion under both the current model and the reference.
          4. Compute a simple reward per example (averaging over generations).
          5. Compute an "advantage" as (reward - baseline) and form a loss.
        """
        # Tokenize prompts (truncate/pad to max_prompt_length).
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_prompt_length,
        ).to(self.device)
        batch_size = inputs.input_ids.size(0)

        # Generate NUM_GENERATIONS completions per prompt.
        all_gen_ids = []
        for _ in range(self.num_generations):
            # Note: we use do_sample=True to allow stochastic generation; pad_token_id is
            # passed explicitly to avoid the missing-pad-token warning in batched generation.
            gen_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=self.gen_temperature,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            all_gen_ids.append(gen_ids)

        # generate() may stop early, so different rounds can return different sequence
        # lengths; right-pad each round to a common length before concatenating.
        max_len = max(g.size(1) for g in all_gen_ids)
        all_gen_ids = [
            torch.nn.functional.pad(g, (0, max_len - g.size(1)), value=self.tokenizer.pad_token_id)
            for g in all_gen_ids
        ]
        # Concatenate completions along the batch dimension: shape (batch_size * num_generations, seq_len).
        # Ordering is round-major: [prompt_0 gen_0, ..., prompt_{B-1} gen_0, prompt_0 gen_1, ...].
        gen_ids_cat = torch.cat(all_gen_ids, dim=0)

        # Compute (model) log probabilities by doing a forward pass with labels equal to input.
        # NOTE: for simplicity the per-sequence sums below cover the prompt tokens as well as
        # the completion tokens (the original repo restricts the loss to completion tokens).
        outputs = self.model(gen_ids_cat, labels=gen_ids_cat)
        logits = outputs.logits  # shape: (B, seq_len, vocab_size)
        log_probs = torch.log_softmax(logits, dim=-1)
        # Shift for computing the logprob of each token (standard language-modeling trick).
        target_ids = gen_ids_cat[:, 1:]
        log_probs = log_probs[:, :-1, :]
        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
        # Sum over tokens to get a per-sequence log probability.
        seq_log_probs = token_log_probs.sum(dim=-1)  # shape: (B,)

        # Compute reference-model log probabilities (without gradient).
        with torch.no_grad():
            ref_outputs = self.ref_model(gen_ids_cat, labels=gen_ids_cat)
            ref_logits = ref_outputs.logits
            ref_log_probs_full = torch.log_softmax(ref_logits, dim=-1)
            ref_log_probs_full = ref_log_probs_full[:, :-1, :]
            ref_token_log_probs = ref_log_probs_full.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
            ref_seq_log_probs = ref_token_log_probs.sum(dim=-1)  # shape: (B,)

        # Compute KL divergence term per generated sequence using the formula:
        #   KL ≈ exp(ref - model) - (ref - model) - 1
        kl = torch.exp(ref_seq_log_probs - seq_log_probs) - (ref_seq_log_probs - seq_log_probs) - 1
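        # The expression above is a commonly used estimator of KL(pi_theta || pi_ref) (the
        # "k3" form: non-negative, and zero when the two log probabilities agree). Applying
        # it to whole-sequence log probabilities, as done here, is part of this script's
        # simplification; per-token application, as in typical GRPO implementations, is
        # better behaved numerically because exp() of a large summed difference can overflow.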

        # Decode only the generated portion of each sequence (list of strings); the prompt
        # occupies the first prompt_len positions of generate()'s output.
        prompt_len = inputs.input_ids.size(1)
        decoded_completions = self.tokenizer.batch_decode(
            gen_ids_cat[:, prompt_len:], skip_special_tokens=True
        )
        # Reshape to (batch_size, num_generations). Because generation rounds were stacked
        # along the batch dimension, the completions for prompt i sit at indices
        # i, i + batch_size, i + 2 * batch_size, ...
        decoded_completions = [
            decoded_completions[i::batch_size] for i in range(batch_size)
        ]

        # Compute a simple reward for each prompt by averaging over its generations.
        rewards = []
        for i, gen_list in enumerate(decoded_completions):
            r_total = 0.0
            for gen_text in gen_list:
                # Use the simple accuracy reward (can be replaced with a more advanced function).
                r_total += accuracy_reward_simple(gen_text, solutions[i])
            rewards.append(r_total / self.num_generations)
        rewards = torch.tensor(rewards, device=self.device)

        # Compute advantage: subtract the baseline (mean reward over batch).
        baseline = rewards.mean()
        advantages = rewards - baseline
        # Tile the per-prompt advantages so that they line up with the round-major ordering
        # of seq_log_probs ([prompt_0 gen_0, ..., prompt_{B-1} gen_0, prompt_0 gen_1, ...]).
        advantages_rep = advantages.repeat(self.num_generations)

        # The GRPO loss is defined here as:
        #   loss = -mean(advantage * logprob) + beta * mean(KL)
        loss = -(advantages_rep * seq_log_probs).mean() + self.beta * kl.mean()
        return loss, decoded_completions, rewards

    def train(self, num_epochs, batch_size):
        """
        Runs the training loop over the train_dataset.
        """
        dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.model.train()  # set model to training mode
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for batch in dataloader:
                # Each batch is expected to be a dict with keys "prompt" and "solution".
                prompts = batch["prompt"]
                solutions = batch["solution"]
                self.optimizer.zero_grad()
                loss, decoded_completions, rewards = self.compute_loss(prompts, solutions)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
            avg_loss = epoch_loss / len(dataloader)
            print(f"Epoch {epoch + 1}/{num_epochs} - Average Loss: {avg_loss:.4f}")
        print("Training complete.")


# ================================
# MAIN SCRIPT
# ================================

def main():
    # 1. Load and format the dataset.
    print("Loading dataset...")
    raw_dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
    # Map each example to a conversation (adds "prompt" and "solution" keys).
    dataset = raw_dataset.map(make_conversation)
    # (Optionally, remove any unneeded columns.)
    dataset = dataset.remove_columns(
        [col for col in raw_dataset.column_names if col not in ["prompt", "solution"]]
    )
    print(f"Loaded {len(dataset)} training examples.")

    # 2. Load the model and tokenizer.
    print("Loading model and tokenizer...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # (Make sure the tokenizer has a pad token.)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models should be left-padded for batched generation, so that new tokens
    # are appended directly after each prompt.
    tokenizer.padding_side = "left"
    model.to(DEVICE)

    # Create a reference model (a second copy of the pretrained weights, kept in evaluation
    # mode; it is only ever used under torch.no_grad()).
    ref_model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME)
    ref_model.to(DEVICE)
    ref_model.eval()

    # 3. Prepare the optimizer.
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # 4. Instantiate the simplified GRPO trainer.
    trainer = SimpleGRPOTrainer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        optimizer=optimizer,
        device=DEVICE,
        num_generations=NUM_GENERATIONS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        beta=BETA,
        gen_temperature=GEN_TEMPERATURE,
    )

    # 5. Run training.
    print("Starting training...")
    trainer.train(num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE)

    # 6. Save the final model.
    print(f"Saving model to {OUTPUT_DIR} ...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Model saved.")


if __name__ == "__main__":
    main()