
Qwen2VL-GRPO-o3mini

#!/usr/bin/env python
"""
A simplified GRPO training script for Qwen2-VL.
This script:
 - Loads a dataset (assumed to have a "problem" and "solution" field).
 - Formats each example as a conversation prompt.
 - Loads the Qwen2-VL model and its tokenizer.
 - Creates a "reference" model (a copy of the loaded model) used for KL regularization.
 - Runs a custom GRPO training loop:
     * For each batch, generates multiple completions.
     * Computes per-token log probabilities from both the current and reference model.
     * Computes a (very simple) reward per example.
     * Computes a GRPO loss as a (negative) advantage-weighted log probability plus a KL penalty.
 - Saves the final model.

IMPORTANT:
 - This script is intentionally simplified.
 - Many details from the original repo (e.g. image support, multi-reward functions, accelerator handling) are removed.
 - If you need more "real" math verification or additional reward functions, please update the reward functions accordingly.

Adjust the placeholder configuration values near the top as needed.
"""
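# At a glance (summarizing the loss implemented in SimpleGRPOTrainer.compute_loss below):
#     loss = -mean( advantage * seq_log_prob ) + BETA * mean( KL )
# where advantage = reward - mean(reward over the batch) and the KL term is estimated per
# sequence as exp(ref - model) - (ref - model) - 1 on the summed log-probabilities.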

import os
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from datasets import load_dataset
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoTokenizer,
)

# ================================
# CONFIGURATION (EDIT THESE VALUES)
# ================================
# Dataset & model
DATASET_NAME = "your_dataset_name"         # e.g., "HuggingFaceH4/YourDataset"
DATASET_CONFIG = "default"                 # if needed; otherwise, leave as "default"
TRAIN_SPLIT = "train"                      # training split name
# (Assume the dataset examples have at least the keys "problem" and "solution".)
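# Illustrative only (not from the original dataset): a record is assumed to look roughly like
#     {"problem": "What is 15 * 4?", "solution": "60"}
# Adapt make_conversation() below if your dataset uses different keys.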

MODEL_NAME = "your-model-org/Qwen2-VL"     # placeholder model name (must be a Qwen2-VL checkpoint)
OUTPUT_DIR = "./grpo_trained_model"        # where to save the model

# Training hyperparameters
NUM_EPOCHS = 3
BATCH_SIZE = 2                             # adjust for your GPU memory
LEARNING_RATE = 5e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# GRPO-specific hyperparameters
MAX_PROMPT_LENGTH = 512                    # maximum tokens for the prompt
MAX_COMPLETION_LENGTH = 128                # maximum new tokens to generate per prompt
NUM_GENERATIONS = 3                        # number of generations per prompt for GRPO
BETA = 0.1                                 # KL regularization coefficient
GEN_TEMPERATURE = 1.0                      # temperature for generation

# A simple system prompt (used to "prime" the conversation)
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The assistant is expected to think through the problem "
    "and then answer. The reasoning is enclosed within <think>...</think> and the final answer within <answer>...</answer>."
)

# ================================
# UTILITY FUNCTIONS & CLASSES
# ================================

def accuracy_reward_simple(completion: str, solution: str) -> float:
    """
    A very simple reward function: returns 1.0 if the (stripped) solution text appears
    in the generated completion; else returns 0.0.
    """
    return 1.0 if solution.strip() in completion.strip() else 0.0

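# Optional, illustrative alternative (not part of the original script): a slightly stricter
# reward that only credits the text inside the <answer>...</answer> span that SYSTEM_PROMPT
# asks for. This is a sketch; to use it, call it inside SimpleGRPOTrainer.compute_loss in
# place of accuracy_reward_simple.
def answer_tag_reward(completion: str, solution: str) -> float:
    """Return 1.0 if the last <answer>...</answer> span matches the solution exactly."""
    import re  # local import keeps this optional sketch self-contained
    matches = re.findall(r"<answer>(.*?)</answer>", completion, flags=re.DOTALL)
    if not matches:
        return 0.0
    # The decoded text still contains the prompt (which itself mentions <answer>...</answer>),
    # so only the last tagged span is treated as the model's answer.
    return 1.0 if matches[-1].strip() == solution.strip() else 0.0
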
def make_conversation(example):
    """
    Formats a dataset example into a conversation.
    Assumes each example has a "problem" and a "solution".
    """
    # Here we simply prepend a system prompt to the problem.
    prompt = f"{SYSTEM_PROMPT}\nUser: {example['problem']}\nAssistant:"
    # The solution (used later for reward computation) remains untouched.
    return {"prompt": prompt, "solution": example["solution"]}

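# For the hypothetical record shown above, the resulting prompt is a single string of the form
#     "<SYSTEM_PROMPT>\nUser: What is 15 * 4?\nAssistant:"
# while the untouched "solution" string is later passed to the reward function.
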
class SimpleGRPOTrainer:
    """
    A simplified GRPO trainer that:
      - For each batch, generates multiple completions per prompt.
      - Computes log probabilities (via a forward pass) on both the current model and a reference model.
      - Computes a simple reward (via an accuracy check).
      - Computes a GRPO loss (negative advantage * logprob plus a KL penalty).
    """
    def __init__(self, model, ref_model, tokenizer, train_dataset, optimizer, device,
                 num_generations, max_prompt_length, max_completion_length, beta, gen_temperature):
        self.model = model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.optimizer = optimizer
        self.device = device
        self.num_generations = num_generations
        self.max_prompt_length = max_prompt_length
        self.max_completion_length = max_completion_length
        self.beta = beta
        self.gen_temperature = gen_temperature

    def compute_loss(self, prompts, solutions):
        """
        For a list of prompt strings and corresponding solution strings:
          1. Tokenize the prompts.
          2. For each prompt, generate NUM_GENERATIONS completions.
          3. Compute per-sequence log probabilities for each completion under both the current model and the reference.
          4. Compute a simple reward per example (averaging over generations).
          5. Compute an "advantage" as (reward - baseline) and form a loss.
        """
        # Tokenize prompts (truncate/pad to max_prompt_length)
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_prompt_length,
        ).to(self.device)
        batch_size = inputs.input_ids.size(0)

        # Generate NUM_GENERATIONS completions per prompt.
        all_gen_ids = []
        for _ in range(self.num_generations):
            # Note: we use do_sample=True to allow stochastic generation.
            gen_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=self.gen_temperature,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            all_gen_ids.append(gen_ids)
        # Sequence lengths can differ between generate() calls (generation may stop early at EOS),
        # so right-pad every batch of generations to a common length before concatenating.
        pad_id = self.tokenizer.pad_token_id
        max_len = max(g.size(1) for g in all_gen_ids)
        all_gen_ids = [
            torch.nn.functional.pad(g, (0, max_len - g.size(1)), value=pad_id)
            for g in all_gen_ids
        ]
        # Concatenate along the batch dimension. Rows are ordered generation-major:
        # row g * batch_size + i holds generation g for prompt i.
        # Shape: (num_generations * batch_size, seq_len)
        gen_ids_cat = torch.cat(all_gen_ids, dim=0)

        # Compute (model) log probabilities with a plain forward pass. An attention mask keeps
        # the padding tokens from influencing the logits, and a matching target mask keeps them
        # out of the per-sequence log-probability sums.
        attention_mask = (gen_ids_cat != pad_id).long()
        outputs = self.model(gen_ids_cat, attention_mask=attention_mask)
        logits = outputs.logits  # shape: (B, seq_len, vocab_size)
        log_probs = torch.log_softmax(logits, dim=-1)
        # Shift for computing the log-prob of each token (standard language-modeling trick)
        target_ids = gen_ids_cat[:, 1:]
        target_mask = attention_mask[:, 1:].float()
        log_probs = log_probs[:, :-1, :]
        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
        # Sum over non-padding tokens to get a per-sequence log probability
        seq_log_probs = (token_log_probs * target_mask).sum(dim=-1)  # shape: (B,)

        # Compute reference model log probabilities (without gradient)
        with torch.no_grad():
            ref_outputs = self.ref_model(gen_ids_cat, attention_mask=attention_mask)
            ref_logits = ref_outputs.logits
            ref_log_probs_full = torch.log_softmax(ref_logits, dim=-1)
            ref_log_probs_full = ref_log_probs_full[:, :-1, :]
            ref_token_log_probs = ref_log_probs_full.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
            ref_seq_log_probs = (ref_token_log_probs * target_mask).sum(dim=-1)  # shape: (B,)

        # Estimate a KL divergence term per generated sequence using the formula
        #   KL ≈ exp(ref - model) - (ref - model) - 1
        # (the per-token estimator commonly used in GRPO, applied here to summed sequence log-probs).
        kl = torch.exp(ref_seq_log_probs - seq_log_probs) - (ref_seq_log_probs - seq_log_probs) - 1

        # Decode the generated completions (list of strings; each string still contains the prompt text)
        decoded_completions = self.tokenizer.batch_decode(gen_ids_cat, skip_special_tokens=True)
        # Group completions by prompt. gen_ids_cat is ordered generation-major
        # (row g * batch_size + i belongs to prompt i), so prompt i owns every batch_size-th entry.
        decoded_completions = [
            decoded_completions[i::batch_size]
            for i in range(batch_size)
        ]

        # Compute a simple reward for each prompt by averaging over its generations.
        # (A fuller GRPO implementation would keep a separate reward per generation and normalize
        # within the group; averaging is part of this script's simplification.)
        rewards = []
        for i, gen_list in enumerate(decoded_completions):
            r_total = 0.0
            for gen_text in gen_list:
                # Use the simple accuracy reward (can be replaced with a more advanced function)
                r_total += accuracy_reward_simple(gen_text, solutions[i])
            rewards.append(r_total / self.num_generations)
        rewards = torch.tensor(rewards, device=self.device)

        # Compute advantage: subtract the baseline (mean reward over the batch)
        baseline = rewards.mean()
        advantages = rewards - baseline
        # Tile the advantages so their ordering matches the generation-major layout of seq_log_probs
        # (repeat, not repeat_interleave: row g * batch_size + i must get the advantage of prompt i).
        advantages_rep = advantages.repeat(self.num_generations)

        # The GRPO loss is defined here as:
        #    loss = -mean( advantage * logprob ) + beta * mean(KL)
        loss = -(advantages_rep * seq_log_probs).mean() + self.beta * kl.mean()

        return loss, decoded_completions, rewards

    def train(self, num_epochs, batch_size):
        """
        Runs the training loop over the train_dataset.
        """
        dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.model.train()  # set model to training mode

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for batch in dataloader:
                # Each batch is expected to be a dict with keys "prompt" and "solution"
                prompts = batch["prompt"]
                solutions = batch["solution"]

                self.optimizer.zero_grad()
                loss, decoded_completions, rewards = self.compute_loss(prompts, solutions)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
            avg_loss = epoch_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")
        print("Training complete.")


# ================================
# MAIN SCRIPT
# ================================
def main():
    # 1. Load and format the dataset.
    print("Loading dataset...")
    raw_dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
    # Map each example to a conversation (adds "prompt" and "solution" keys).
    dataset = raw_dataset.map(make_conversation)
    # (Optionally, remove any unneeded columns.)
    dataset = dataset.remove_columns([col for col in raw_dataset.column_names if col not in ["prompt", "solution"]])
    print(f"Loaded {len(dataset)} training examples.")

    # 2. Load model and tokenizer.
    print("Loading model and tokenizer...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # (Make sure the tokenizer has a pad token, and left-pad prompts: decoder-only models
    # should be left-padded for batched generation.)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model.to(DEVICE)
    # Create a reference model (a copy in evaluation mode; no gradients needed)
    ref_model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME)
    ref_model.to(DEVICE)
    ref_model.eval()

    # 3. Prepare the optimizer.
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # 4. Instantiate the simplified GRPO trainer.
    trainer = SimpleGRPOTrainer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        optimizer=optimizer,
        device=DEVICE,
        num_generations=NUM_GENERATIONS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        beta=BETA,
        gen_temperature=GEN_TEMPERATURE,
    )

    # 5. Run training.
    print("Starting training...")
    trainer.train(num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE)

    # 6. Save the final model.
    print(f"Saving model to {OUTPUT_DIR} ...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Model saved.")


if __name__ == "__main__":
    main()

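# To run (after filling in DATASET_NAME and MODEL_NAME above), save this file under any name,
# e.g. grpo_qwen2vl_simple.py (the filename is illustrative; the script takes no CLI arguments):
#     python grpo_qwen2vl_simple.py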