Train.py

import wandb
import os
import json
import torch
from accelerate import Accelerator
from data.dataset import create_splits
from model.model_loader import load_gen_model_and_processor
from configs.config import TrainingConfig
from trl import SFTTrainer, SFTConfig
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLProcessor, TrainerCallback
from peft import LoraConfig
from utils import clear_memory

config = TrainingConfig()

class ClearCudaCacheCallback(TrainerCallback):
    """After every optimization step, free any leftover cached memory."""

    def on_step_end(self, args, state, control, **kwargs):
        torch.cuda.empty_cache()
        return control

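# Note: torch.cuda.empty_cache() returns cached allocator blocks to the
# driver, lowering peak/reported memory between steps at the cost of some
# re-allocation overhead; PyTorch would otherwise reuse those blocks itself.
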
def main():
    clear_memory()

    accelerator = Accelerator()  # initialize distributed state; the Trainer drives DeepSpeed itself

    training_args = SFTConfig(
        output_dir=config.output_dir,
        run_name=config.wandb_run_name,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        learning_rate=config.lr,
        lr_scheduler_type="constant",
        logging_steps=10,
        eval_steps=10,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=20,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        fp16=True,
        bf16=False,
        max_grad_norm=config.max_grad_norm,
        warmup_ratio=config.warmup_ratio,
        push_to_hub=False,
        report_to="wandb",
        gradient_checkpointing_kwargs={"use_reentrant": False},
        dataset_kwargs={"skip_prepare_dataset": True},
        deepspeed="configs/ds_config.json",
        max_seq_length=1024,
    )

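    # Effective batch size: per_device_train_batch_size (1) x
    # gradient_accumulation_steps (8) = 8 samples per optimizer step
    # per device, times the number of processes under DeepSpeed.
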
    wandb.init(
        project=config.wandb_project,
        name=config.wandb_run_name,
        config=config,
    )
    model, processor = load_gen_model_and_processor(config)
    model.config.use_cache = False  # KV caching is incompatible with gradient checkpointing

    # Collates samples from the dataset and prepares labels for the model,
    # masking padding and image tokens so they are excluded from the loss.
    def collate_fn(samples):
        """Each sample is a chat-formatted list of messages, e.g.
        [
            {'role': 'system', 'content': [...]},
            {'role': 'user', 'content': [...]},
            {'role': 'assistant', 'content': [...]},
        ]"""
        prompts = [
            processor.apply_chat_template(sample, tokenize=False) for sample in samples
        ]
        # Process vision inputs (returns a tuple, so take the image part) and
        # resize each image to the processor's target size; for Qwen2.5-VL the
        # size is a dict with "shortest_edge" and "longest_edge".
        target_size = processor.image_processor.size
        image_inputs = []
        for sample in samples:
            image = process_vision_info(sample)[0]
            if isinstance(image, list):
                if len(image) == 1:
                    image = image[0]
                else:
                    raise ValueError(
                        f"Expected a single image, got a list of length {len(image)}"
                    )
            # Resize image to the model's expected input size
            if hasattr(image, "resize"):
                image = image.resize(
                    (target_size["shortest_edge"], target_size["shortest_edge"])
                )
            else:
                raise TypeError(f"Unsupported image type: {type(image)}")
            image_inputs.append(image)

        # Tokenize and encode the whole batch; this must run after the loop,
        # once every image has been collected, so that one forward batch is
        # built and returned for all samples.
        batch = processor(
            text=prompts,
            images=image_inputs,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
        )
        labels = batch["input_ids"].clone()
        # Mask padding so it does not contribute to the loss
        labels[labels == processor.tokenizer.pad_token_id] = -100
        # Mask Qwen-specific image tokens as well
        if isinstance(processor, Qwen2_5_VLProcessor):
            image_tokens = [151652, 151653, 151655]
        else:
            image_tokens = [
                processor.tokenizer.convert_tokens_to_ids(processor.image_token)
            ]
        for image_token_id in image_tokens:
            labels[labels == image_token_id] = -100
        batch["labels"] = labels
        return batch

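    # Illustration of the label masking above (a sketch, not executed):
    # for a row of input_ids such as
    #   [151652, 9906, 151653, 151655, 220, <pad>]
    # the corresponding labels row becomes
    #   [-100, 9906, -100, -100, 220, -100]
    # so cross-entropy is computed only at the unmasked text positions.
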
    train_dataset, eval_dataset, test_dataset = create_splits(
        config.json_path, config.image_dir, config.train, config.val, config.test
    )

    # Persist the held-out test split so it can be evaluated later
    output_test_dir = os.path.join(config.output_dir, "test")
    os.makedirs(output_test_dir, exist_ok=True)

    test_data = list(test_dataset)

    test_file_path = os.path.join(output_test_dir, "test_data.json")
    with open(test_file_path, "w") as f:
        json.dump(test_data, f, indent=2)

    peft_config = LoraConfig(
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        r=config.lora_r,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )

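    # Sanity check (a sketch, for inspection only -- SFTTrainer applies
    # peft_config itself, so do not pass a pre-wrapped model below):
    #   from peft import get_peft_model
    #   peft_model = get_peft_model(model, peft_config)
    #   peft_model.print_trainable_parameters()
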
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor.tokenizer,
        callbacks=[ClearCudaCacheCallback()],
    )

    trainer.train()

    trainer.save_model(config.output_dir)


if __name__ == "__main__":
    main()
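
# ---------------------------------------------------------------------------
# Example configs/ds_config.json (a minimal sketch, not necessarily the config
# used with this script; assumes ZeRO stage 2 with fp16 to match fp16=True
# above -- "auto" values are filled in from the TrainingArguments by the
# Transformers DeepSpeed integration):
#
# {
#     "fp16": {"enabled": "auto"},
#     "zero_optimization": {"stage": 2},
#     "train_micro_batch_size_per_gpu": "auto",
#     "gradient_accumulation_steps": "auto",
#     "gradient_clipping": "auto"
# }
#
# Typical multi-GPU launch, e.g.:
#   accelerate launch train.py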