Guest User

ModernBERT fine tuning

a guest
Apr 29th, 2025
292
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 20.64 KB | Source Code | 0 0
  1. import argparse
  2. import json
  3. import logging
  4. import os
  5. import random
  6. import time
  7. from typing import List, Tuple
  8.  
  9. import numpy as np
  10. import torch
  11. from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
  12. from sklearn.model_selection import train_test_split
  13. from torch.optim import AdamW
  14. from torch.utils.data import DataLoader, Dataset
  15. from tqdm import tqdm
  16. from transformers import (
  17.     AutoModelForSequenceClassification,
  18.     AutoTokenizer,
  19.     get_linear_schedule_with_warmup,
  20. )
  21.  
  22. logging.basicConfig(
  23.     level=logging.INFO,
  24.     format="%(asctime)s - %(levelname)s - %(message)s",
  25. )
  26. logger = logging.getLogger(__name__)
  27.  
  28. # Ensure deterministic behavior
  29. RANDOM_SEED = 42
  30. random.seed(RANDOM_SEED)
  31. np.random.seed(RANDOM_SEED)
  32. torch.manual_seed(RANDOM_SEED)
  33.  
  34. class JSONLDataset(Dataset):
  35.     """Dataset for loading and processing JSONL data."""
  36.  
  37.     def __init__(self, texts, labels, tokenizer, max_length=128):
  38.         """
  39.        Initialize dataset with texts and corresponding labels.
  40.  
  41.        Args:
  42.            texts: List of text strings
  43.            labels: List of labels (EVERGREEN or TIME-SENSITIVE)
  44.            tokenizer: Pretrained tokenizer
  45.            max_length: Maximum sequence length for tokenization
  46.        """
  47.         self.texts = texts
  48.         self.labels = labels
  49.         self.tokenizer = tokenizer
  50.         self.max_length = max_length
  51.         self.label_map = {"EVERGREEN": 0, "TIME-SENSITIVE": 1}
  52.  
  53.     def __len__(self):
  54.         return len(self.texts)
  55.  
  56.     def __getitem__(self, idx):
  57.         text = self.texts[idx]
  58.         label = self.label_map[self.labels[idx]]
  59.  
  60.         encoding = self.tokenizer(
  61.             text,
  62.             truncation=True,
  63.             padding="max_length",
  64.             max_length=self.max_length,
  65.             return_tensors="pt",
  66.         )
  67.  
  68.         # Remove the batch dimension added by the tokenizer
  69.         return {
  70.             "input_ids": encoding["input_ids"].squeeze(),
  71.             "attention_mask": encoding["attention_mask"].squeeze(),
  72.             "labels": torch.tensor(label, dtype=torch.long),
  73.         }
  74.  
  75.  
  76. def load_jsonl_data(file_path: str) -> Tuple[List[str], List[str]]:
  77.     """
  78.    Load and parse a JSONL file containing text and label fields.
  79.  
  80.    Args:
  81.        file_path: Path to the JSONL file
  82.  
  83.    Returns:
  84.        Tuple of (texts, labels)
  85.    """
  86.     texts = []
  87.     labels = []
  88.  
  89.     try:
  90.         with open(file_path, "r", encoding="utf-8") as f:
  91.             for line in f:
  92.                 try:
  93.                     item = json.loads(line.strip())
  94.                     # Ensure the required fields exist
  95.                     if "text" not in item or "label" not in item:
  96.                         logger.warning(f"Skipping line due to missing fields: {line.strip()}")
  97.                         continue
  98.  
  99.                     # Validate the label
  100.                     if item["label"] not in ["EVERGREEN", "TIME-SENSITIVE"]:
  101.                         logger.warning(f"Skipping line due to invalid label: {line.strip()}")
  102.                         continue
  103.                     texts.append(item["text"])
  104.                     labels.append(item["label"])
  105.                 except json.JSONDecodeError:
  106.                     logger.warning(f"Skipping line due to JSON parsing error: {line.strip()}")
  107.     except FileNotFoundError:
  108.         logger.error(f"Data file not found: {file_path}")
  109.         raise
  110.     except Exception as e:
  111.         logger.error(f"Error loading data from {file_path}: {e}")
  112.         raise
  113.  
  114.     if not texts:
  115.          logger.error(f"No valid examples loaded from {file_path}. Please check the file format and content.")
  116.          # Optionally raise an error or exit
  117.          raise ValueError(f"No valid data loaded from {file_path}")
  118.  
  119.  
  120.     logger.info(f"Loaded {len(texts)} valid examples from {file_path}")
  121.     return texts, labels
  122.  
  123.  
  124. def train_model(
  125.         model,
  126.         train_dataloader,
  127.         val_dataloader,
  128.         device, # Added device parameter
  129.         epochs=4,
  130.         learning_rate=2e-5,
  131.         weight_decay=0.01,
  132.         warmup_proportion=0.1,
  133. ):
  134.     """
  135.    Train the model and evaluate on validation data.
  136.  
  137.    Args:
  138.        model: The transformer model
  139.        train_dataloader: DataLoader for training data
  140.        val_dataloader: DataLoader for validation data
  141.        device: The torch device (e.g., torch.device("cpu"))
  142.        epochs: Number of training epochs
  143.        learning_rate: Learning rate for optimizer
  144.        weight_decay: Weight decay for regularization
  145.        warmup_proportion: Proportion of training steps for LR warmup
  146.  
  147.    Returns:
  148.        Trained model (best state based on validation accuracy)
  149.    """
  150.     no_decay = ["bias", "LayerNorm.weight"]
  151.     optimizer_grouped_parameters = [
  152.         {
  153.             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
  154.             "weight_decay": weight_decay,
  155.         },
  156.         {
  157.             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
  158.             "weight_decay": 0.0,
  159.         },
  160.     ]
  161.  
  162.     optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
  163.  
  164.     # Calculate total training steps for scheduler
  165.     total_steps = len(train_dataloader) * epochs
  166.     warmup_steps = int(total_steps * warmup_proportion)
  167.     scheduler = get_linear_schedule_with_warmup(
  168.         optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
  169.     )
  170.  
  171.     logger.info(f"Starting training for {epochs} epochs on device: {device}")
  172.     best_val_accuracy = 0.0
  173.     best_model_state = None
  174.  
  175.     for epoch in range(epochs):
  176.         logger.info(f"Epoch {epoch + 1}/{epochs}")
  177.  
  178.         # Training phase
  179.         model.train()
  180.         train_loss = 0.0
  181.         progress_bar = tqdm(train_dataloader, desc="Training")
  182.  
  183.         for batch in progress_bar:
  184.             optimizer.zero_grad()
  185.  
  186.             # Move batch to the specified device
  187.             input_ids = batch["input_ids"].to(device)
  188.             attention_mask = batch["attention_mask"].to(device)
  189.             labels = batch["labels"].to(device)
  190.  
  191.             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
  192.             loss = outputs.loss
  193.  
  194.             # Check if loss is valid (might be None or NaN in rare cases)
  195.             if loss is not None and not torch.isnan(loss):
  196.                 loss.backward()
  197.                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
  198.                 optimizer.step()
  199.                 scheduler.step()
  200.                 train_loss += loss.item()
  201.                 progress_bar.set_postfix({"loss": loss.item()})
  202.             else:
  203.                 logger.warning("Skipping batch due to invalid loss value.")
  204.  
  205.  
  206.         avg_train_loss = train_loss / len(train_dataloader) if len(train_dataloader) > 0 else 0.0
  207.         logger.info(f"Average training loss: {avg_train_loss:.4f}")
  208.  
  209.         # Validation phase
  210.         val_accuracy, val_loss = evaluate_model(model, val_dataloader, device) # Pass device
  211.         logger.info(f"Validation accuracy: {val_accuracy:.4f}, loss: {val_loss:.4f}")
  212.  
  213.         # Save best model
  214.         if val_accuracy > best_val_accuracy:
  215.             best_val_accuracy = val_accuracy
  216.             # Ensure model state is moved to CPU before copying if it was on GPU
  217.             # (Not strictly necessary here as we're on CPU, but good practice)
  218.             best_model_state = {k: v.cpu() for k, v in model.state_dict().items()}
  219.             logger.info(f"New best model found with validation accuracy: {val_accuracy:.4f}")
  220.  
  221.     # Load best model state for final return
  222.     if best_model_state:
  223.         logger.info(f"Restoring best model with validation accuracy: {best_val_accuracy:.4f}")
  224.         model.load_state_dict(best_model_state)
  225.     else:
  226.         logger.warning("No best model state saved (validation accuracy did not improve). Returning model from last epoch.")
  227.  
  228.  
  229.     return model
  230.  
  231.  
  232. def evaluate_model(model, dataloader, device): # Added device parameter
  233.     """
  234.    Evaluate the model on a dataset.
  235.  
  236.    Args:
  237.        model: The transformer model
  238.        dataloader: DataLoader for evaluation data
  239.        device: The torch device (e.g., torch.device("cpu"))
  240.  
  241.    Returns:
  242.        Tuple of (accuracy, average_loss)
  243.    """
  244.     model.eval()
  245.     true_labels = []
  246.     pred_labels = []
  247.     total_loss = 0.0
  248.  
  249.     with torch.no_grad():
  250.         for batch in tqdm(dataloader, desc="Evaluating"):
  251.             input_ids = batch["input_ids"].to(device)
  252.             attention_mask = batch["attention_mask"].to(device)
  253.             labels = batch["labels"].to(device)
  254.  
  255.             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
  256.             loss = outputs.loss
  257.             logits = outputs.logits
  258.  
  259.             if loss is not None:
  260.                  total_loss += loss.item()
  261.  
  262.             # Get predictions
  263.             preds = torch.argmax(logits, dim=1)
  264.  
  265.             # Move predictions and labels to CPU for sklearn metrics
  266.             true_labels.extend(labels.cpu().numpy())
  267.             pred_labels.extend(preds.cpu().numpy())
  268.  
  269.     if not true_labels:
  270.         logger.warning("Evaluation dataloader was empty.")
  271.         return 0.0, 0.0
  272.  
  273.     accuracy = accuracy_score(true_labels, pred_labels)
  274.     avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0.0
  275.  
  276.     return accuracy, avg_loss
  277.  
  278.  
  279. def test_model(model, test_dataloader, label_map, device): # Added device parameter
  280.     """
  281.    Test the model and print detailed metrics.
  282.  
  283.    Args:
  284.        model: The transformer model
  285.        test_dataloader: DataLoader for test data
  286.        label_map: Mapping from label names to indices
  287.        device: The torch device (e.g., torch.device("cpu"))
  288.  
  289.    Returns:
  290.        Test accuracy
  291.    """
  292.     model.eval()
  293.     true_labels = []
  294.     pred_labels = []
  295.  
  296.     reverse_label_map = {v: k for k, v in label_map.items()}
  297.  
  298.     with torch.no_grad():
  299.         for batch in tqdm(test_dataloader, desc="Testing"):
  300.             # Move batch to the specified device
  301.             input_ids = batch["input_ids"].to(device)
  302.             attention_mask = batch["attention_mask"].to(device)
  303.             # Labels are part of the batch from the dataset, move them too for comparison
  304.             labels = batch["labels"].to(device)
  305.  
  306.             # Don't pass labels to the model during inference
  307.             outputs = model(input_ids=input_ids, attention_mask=attention_mask)
  308.             logits = outputs.logits
  309.  
  310.             # Get predictions
  311.             preds = torch.argmax(logits, dim=1)
  312.  
  313.             # Move predictions and labels to CPU for sklearn metrics
  314.             true_labels.extend(labels.cpu().numpy())
  315.             pred_labels.extend(preds.cpu().numpy())
  316.  
  317.     if not true_labels: # Handle empty dataloader case
  318.         logger.error("Test dataloader was empty. Cannot calculate metrics.")
  319.         return 0.0
  320.  
  321.     accuracy = accuracy_score(true_labels, pred_labels)
  322.  
  323.     target_names = [reverse_label_map[i] for i in sorted(reverse_label_map.keys())]
  324.     try:
  325.         report = classification_report(true_labels, pred_labels, target_names=target_names, zero_division=0)
  326.     except ValueError as e:
  327.         logger.warning(f"Could not generate classification report: {e}")
  328.         # This can happen if only one class is present in predictions or true labels
  329.         # Log unique values to help debug
  330.         logger.warning(f"Unique true labels: {np.unique(true_labels)}")
  331.         logger.warning(f"Unique predicted labels: {np.unique(pred_labels)}")
  332.         report = "N/A"
  333.  
  334.     try:
  335.         conf_matrix = confusion_matrix(true_labels, pred_labels, labels=sorted(reverse_label_map.keys()))
  336.     except ValueError as e:
  337.         logger.warning(f"Could not generate confusion matrix: {e}")
  338.         conf_matrix = "N/A"
  339.  
  340.     logger.info(f"Test accuracy: {accuracy:.4f}")
  341.     logger.info(f"Classification report:\n{report}")
  342.     logger.info(f"Confusion matrix (Labels: {sorted(reverse_label_map.keys())}):")
  343.     logger.info(f"{conf_matrix}")
  344.  
  345.  
  346.     return accuracy
  347.  
  348. def main():
  349.     """Main function to run the transformer fine-tuning process."""
  350.     parser = argparse.ArgumentParser(description="Fine-tune a transformer model for text classification on CPU")
  351.  
  352.     parser.add_argument(
  353.         "--data",
  354.         type=str,
  355.         required=True,
  356.         help="Path to JSONL file with 'text' and 'label' fields ('EVERGREEN' or 'TIME-SENSITIVE')"
  357.     )
  358.     parser.add_argument(
  359.         "--model",
  360.         type=str,
  361.         default="answerdotai/ModernBERT-base",
  362.         help="Base model identifier from Hugging Face Hub or local path (default: ModernBERT-base)"
  363.     )
  364.     parser.add_argument(
  365.         "--output_dir",
  366.         type=str,
  367.         default=None,
  368.         help="Directory to save the fine-tuned model and results"
  369.     )
  370.     parser.add_argument(
  371.         "--batch_size",
  372.         type=int,
  373.         default=4,
  374.         help="Batch size for training and evaluation (default: 4)"
  375.     )
  376.     parser.add_argument(
  377.         "--epochs",
  378.         type=int,
  379.         default=3,
  380.         help="Number of training epochs (default: 3)"
  381.     )
  382.     parser.add_argument(
  383.         "--learning_rate",
  384.         type=float,
  385.         default=2e-5,
  386.         help="Learning rate for AdamW optimizer (default: 2e-5)"
  387.     )
  388.     parser.add_argument(
  389.         "--max_length",
  390.         type=int,
  391.         default=2048,
  392.         help="Maximum sequence length for tokenizer (default: 2048)"
  393.     )
  394.     parser.add_argument(
  395.         "--test_size",
  396.         type=float,
  397.         default=0.15,
  398.         help="Proportion of data to use for testing (default: 0.15)"
  399.     )
  400.     parser.add_argument(
  401.         "--val_size",
  402.         type=float,
  403.         default=0.15,
  404.         help="Proportion of *training* data to use for validation (default: 0.15)"
  405.     )
  406.  
  407.     args = parser.parse_args()
  408.  
  409.     # Extract the JSONL filename without path and extension
  410.     data_filename = os.path.splitext(os.path.basename(args.data))[0]
  411.  
  412.     # Extract the model name (last part after the last slash or the whole string if no slash)
  413.     model_name = args.model.split("/")[-1]
  414.  
  415.     # Set the output directory using the JSONL filename and model name if not specified
  416.     if args.output_dir is None:
  417.         args.output_dir = f"./finetuned_{data_filename}_{model_name}"
  418.  
  419.     device = torch.device("cpu")
  420.     logger.info(f"Using device: {device}")
  421.  
  422.     os.makedirs(args.output_dir, exist_ok=True)
  423.     logger.info(f"Loading data from {args.data}")
  424.     try:
  425.         texts, labels = load_jsonl_data(args.data)
  426.     except (FileNotFoundError, ValueError) as e:
  427.          logger.error(f"Failed to load data: {e}")
  428.          return
  429.  
  430.     # Split data into train+val and test sets
  431.     logger.info(f"Splitting data: Test size={args.test_size}, Validation size={args.val_size} (of training set)")
  432.     train_texts, test_texts, train_labels, test_labels = train_test_split(
  433.         texts, labels, test_size=args.test_size, random_state=RANDOM_SEED, stratify=labels
  434.     )
  435.  
  436.     # Split training data further to create a validation set
  437.     # Ensure val_size is not too large, preventing tiny training sets
  438.     if len(train_texts) < 2 or args.val_size >= 1.0 or args.val_size <= 0.0:
  439.          logger.warning("Validation split size invalid or training set too small. Skipping validation split.")
  440.          val_texts, val_labels = [], [] # No validation set
  441.     else:
  442.          # Calculate validation size relative to the current train set
  443.          relative_val_size = args.val_size # Keep interpretation simple: proportion of current train set
  444.  
  445.          # Handle cases where the training set is very small after the test split
  446.          if int(len(train_texts) * relative_val_size) < 1:
  447.               logger.warning(f"Calculated validation set size is less than 1 example. Adjusting validation split.")
  448.               # Decide on a strategy: either skip validation or take at least one sample
  449.               if len(train_texts) > 1:
  450.                    relative_val_size = 1 / len(train_texts) # Take one sample for validation
  451.                    logger.info(f"Using 1 example for validation.")
  452.               else:
  453.                    logger.warning("Training set has only 1 example after test split. Cannot create validation set.")
  454.                    val_texts, val_labels = [], []
  455.                    relative_val_size = 0 # Ensure split doesn't happen
  456.  
  457.          if relative_val_size > 0:
  458.               train_texts, val_texts, train_labels, val_labels = train_test_split(
  459.                   train_texts, train_labels, test_size=relative_val_size, random_state=RANDOM_SEED, stratify=train_labels
  460.               )
  461.  
  462.  
  463.     logger.info(f"Train set size: {len(train_texts)}")
  464.     logger.info(f"Validation set size: {len(val_texts)}")
  465.     logger.info(f"Test set size: {len(test_texts)}")
  466.  
  467.     if len(train_texts) == 0:
  468.          logger.error("Training set is empty after splitting. Cannot proceed.")
  469.          return
  470.  
  471.  
  472.     # Load tokenizer and model
  473.     try:
  474.         logger.info(f"Loading tokenizer for model: {args.model}")
  475.         tokenizer = AutoTokenizer.from_pretrained(args.model)
  476.  
  477.         logger.info(f"Loading model: {args.model}")
  478.         model = AutoModelForSequenceClassification.from_pretrained(
  479.             args.model, num_labels=2 # Binary classification: EVERGREEN (0), TIME-SENSITIVE (1)
  480.         )
  481.  
  482.         model.to(device)
  483.         logger.info(f"Model moved to device: {device}")
  484.  
  485.     except OSError as e:
  486.         logger.error(f"Could not load model or tokenizer '{args.model}'. "
  487.                      f"Ensure it's a valid Hugging Face model identifier or local path. Error: {e}")
  488.         return
  489.     except Exception as e:
  490.         logger.error(f"An unexpected error occurred loading model/tokenizer: {e}")
  491.         return
  492.  
  493.     train_dataset = JSONLDataset(train_texts, train_labels, tokenizer, max_length=args.max_length)
  494.     val_dataset = JSONLDataset(val_texts, val_labels, tokenizer, max_length=args.max_length) if val_texts else None
  495.     test_dataset = JSONLDataset(test_texts, test_labels, tokenizer, max_length=args.max_length)
  496.  
  497.     train_dataloader = DataLoader(
  498.         train_dataset, batch_size=args.batch_size, shuffle=True # num_workers= can be added
  499.     )
  500.     val_dataloader = DataLoader(
  501.         val_dataset, batch_size=args.batch_size
  502.     ) if val_dataset else None
  503.     test_dataloader = DataLoader(
  504.         test_dataset, batch_size=args.batch_size
  505.     )
  506.  
  507.     if not val_dataloader:
  508.          logger.warning("No validation set/dataloader created. Training will proceed without validation pauses or best model saving based on validation.")
  509.  
  510.     logger.info("Starting model training...")
  511.     start_time = time.time()
  512.  
  513.     model = train_model(
  514.         model,
  515.         train_dataloader,
  516.         val_dataloader,
  517.         device,
  518.         epochs=args.epochs,
  519.         learning_rate=args.learning_rate
  520.     )
  521.     training_time = time.time() - start_time
  522.     logger.info(f"Training completed in {training_time:.2f} seconds")
  523.  
  524.     logger.info("Evaluating model on the test set...")
  525.     if len(test_dataloader) > 0:
  526.         test_accuracy = test_model(model, test_dataloader, train_dataset.label_map, device)
  527.     else:
  528.         logger.warning("Test dataloader is empty. Skipping testing.")
  529.         test_accuracy = 0.0
  530.  
  531.  
  532.     # Save the final model (which should be the best one if validation occurred)
  533.     logger.info(f"Saving final model and tokenizer to {args.output_dir}")
  534.     try:
  535.         model.save_pretrained(args.output_dir)
  536.         tokenizer.save_pretrained(args.output_dir)
  537.     except Exception as e:
  538.          logger.error(f"Error saving model/tokenizer: {e}")
  539.  
  540.  
  541.     # Save configuration and results
  542.     results = {
  543.         "model_base": args.model,
  544.         "final_test_accuracy": float(test_accuracy),
  545.         "training_time_seconds": round(training_time, 2),
  546.         "max_seq_length": args.max_length,
  547.         "batch_size": args.batch_size,
  548.         "epochs_trained": args.epochs,
  549.         "learning_rate": args.learning_rate,
  550.         "random_seed": RANDOM_SEED,
  551.         "num_train_examples": len(train_texts),
  552.         "num_val_examples": len(val_texts) if val_texts else 0,
  553.         "num_test_examples": len(test_texts),
  554.         "data_source": args.data,
  555.         "output_directory": args.output_dir,
  556.     }
  557.  
  558.     results_path = os.path.join(args.output_dir, "training_results.json")
  559.     try:
  560.         with open(results_path, "w", encoding="utf-8") as f:
  561.             json.dump(results, f, indent=2)
  562.         logger.info(f"Results saved to {results_path}")
  563.     except Exception as e:
  564.         logger.error(f"Error saving results JSON: {e}")
  565.  
  566.  
  567.     logger.info(f"Script finished. Final test accuracy: {test_accuracy:.4f}")
  568.  
  569.  
  570. if __name__ == "__main__":
  571.     main()
  572.  
Advertisement
Add Comment
Please, Sign In to add comment