Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- import json
- import logging
- import os
- import random
- import time
- from typing import List, Tuple
- import numpy as np
- import torch
- from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
- from sklearn.model_selection import train_test_split
- from torch.optim import AdamW
- from torch.utils.data import DataLoader, Dataset
- from tqdm import tqdm
- from transformers import (
- AutoModelForSequenceClassification,
- AutoTokenizer,
- get_linear_schedule_with_warmup,
- )
# Root logger: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# One shared seed for the Python, NumPy and PyTorch random number generators
# so that repeated runs draw the same random numbers.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
class JSONLDataset(Dataset):
    """Map-style dataset that tokenizes labelled texts on access.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    entries, ready to be collated by a ``DataLoader``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        """
        Store the raw examples and the tokenization settings.

        Args:
            texts: List of text strings
            labels: List of label names (EVERGREEN or TIME-SENSITIVE),
                parallel to ``texts``
            tokenizer: Callable tokenizer; invoked with ``truncation``,
                ``padding``, ``max_length`` and ``return_tensors`` keywords
            max_length: Sequence length every example is padded/truncated to
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Label name -> class index used as the training target.
        self.label_map = {"EVERGREEN": 0, "TIME-SENSITIVE": 1}

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Tokenize example ``idx`` and return its tensors.

        Raises:
            KeyError: If the stored label is not a key of ``label_map``.
        """
        text = self.texts[idx]
        label = self.label_map[self.labels[idx]]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension added by the tokenizer.
        # A bare squeeze() would also collapse the sequence dimension when
        # max_length == 1, producing 0-d tensors.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }
def load_jsonl_data(file_path: str) -> Tuple[List[str], List[str]]:
    """
    Load and parse a JSONL file containing text and label fields.

    Blank lines are ignored silently. Lines that are not valid JSON, are not
    JSON objects, lack a ``text`` or ``label`` field, carry a label other than
    ``EVERGREEN``/``TIME-SENSITIVE``, or whose ``text`` is not a string are
    skipped with a warning.

    Args:
        file_path: Path to the JSONL file

    Returns:
        Tuple of (texts, labels), two parallel lists

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file contains no valid example.
    """
    valid_labels = {"EVERGREEN", "TIME-SENSITIVE"}
    texts: List[str] = []
    labels: List[str] = []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Blank separators / trailing newlines are not errors.
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning(f"Skipping line due to JSON parsing error: {line}")
                    continue
                # A line such as `42` or `null` is valid JSON but not an
                # object; membership tests on it would raise TypeError.
                if not isinstance(item, dict) or "text" not in item or "label" not in item:
                    logger.warning(f"Skipping line due to missing fields: {line}")
                    continue
                label = item["label"]
                # isinstance first: an unhashable label (list/dict) cannot be
                # looked up in a set.
                if not isinstance(label, str) or label not in valid_labels:
                    logger.warning(f"Skipping line due to invalid label: {line}")
                    continue
                if not isinstance(item["text"], str):
                    logger.warning(f"Skipping line due to non-string text: {line}")
                    continue
                texts.append(item["text"])
                labels.append(label)
    except FileNotFoundError:
        logger.error(f"Data file not found: {file_path}")
        raise
    except Exception as e:
        logger.error(f"Error loading data from {file_path}: {e}")
        raise
    if not texts:
        logger.error(f"No valid examples loaded from {file_path}. Please check the file format and content.")
        raise ValueError(f"No valid data loaded from {file_path}")
    logger.info(f"Loaded {len(texts)} valid examples from {file_path}")
    return texts, labels
def train_model(
    model,
    train_dataloader,
    val_dataloader,
    device,
    epochs=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_proportion=0.1,
):
    """
    Train the model and, when validation data is given, keep the best epoch.

    After every epoch the model is scored on ``val_dataloader``; the weights
    of the epoch with the highest validation accuracy are snapshotted and
    loaded back into the model before returning.

    Args:
        model: The transformer model; its forward pass must return an object
            with a ``loss`` attribute when called with ``labels``
        train_dataloader: DataLoader for training data
        val_dataloader: DataLoader for validation data, or ``None`` to train
            without validation (the last-epoch weights are returned)
        device: The torch device (e.g., torch.device("cpu"))
        epochs: Number of training epochs
        learning_rate: Learning rate for optimizer
        weight_decay: Weight decay for regularization
        warmup_proportion: Proportion of training steps for LR warmup

    Returns:
        The trained model (best state based on validation accuracy, or the
        final-epoch state if no validation was possible)
    """
    # Parameters whose names contain these markers get no weight decay.
    no_decay = ("bias", "LayerNorm.weight")
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer = AdamW(
        [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ],
        lr=learning_rate,
    )

    # The scheduler is stepped once per optimizer step, so its horizon is
    # the total number of batches over all epochs.
    total_steps = len(train_dataloader) * epochs
    warmup_steps = int(total_steps * warmup_proportion)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    logger.info(f"Starting training for {epochs} epochs on device: {device}")
    best_val_accuracy = 0.0
    best_model_state = None

    for epoch in range(epochs):
        logger.info(f"Epoch {epoch + 1}/{epochs}")

        # Training phase
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc="Training")
        for batch in progress_bar:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            # A NaN or infinite loss would poison the weights through the
            # gradient update, so such batches are skipped entirely.
            if loss is None or not torch.isfinite(loss):
                logger.warning("Skipping batch due to invalid loss value.")
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        num_batches = len(train_dataloader)
        avg_train_loss = train_loss / num_batches if num_batches > 0 else 0.0
        logger.info(f"Average training loss: {avg_train_loss:.4f}")

        # Validation phase (skipped when the caller supplied no validation data)
        if val_dataloader is None:
            continue
        val_accuracy, val_loss = evaluate_model(model, val_dataloader, device)
        logger.info(f"Validation accuracy: {val_accuracy:.4f}, loss: {val_loss:.4f}")

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # copy=True forces a real copy even when the tensor already lives
            # on the CPU; without it the snapshot would alias the live
            # parameters and silently track later epochs.
            best_model_state = {
                name: tensor.detach().to("cpu", copy=True)
                for name, tensor in model.state_dict().items()
            }
            logger.info(f"New best model found with validation accuracy: {val_accuracy:.4f}")

    # Load best model state for final return
    if val_dataloader is None:
        logger.info("No validation data provided. Returning model from last epoch.")
    elif best_model_state is not None:
        logger.info(f"Restoring best model with validation accuracy: {best_val_accuracy:.4f}")
        model.load_state_dict(best_model_state)
    else:
        logger.warning("No best model state saved (validation accuracy did not improve). Returning model from last epoch.")
    return model
def evaluate_model(model, dataloader, device):
    """
    Score the model on every batch of ``dataloader`` without updating it.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: The transformer model
        dataloader: DataLoader yielding dicts with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors
        device: The torch device (e.g., torch.device("cpu"))

    Returns:
        Tuple of (accuracy, average_loss); ``(0.0, 0.0)`` if the dataloader
        produced no examples
    """
    model.eval()
    loss_sum = 0.0
    targets = []
    predictions = []
    batch_keys = ("input_ids", "attention_mask", "labels")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            inputs = {key: batch[key].to(device) for key in batch_keys}
            outputs = model(**inputs)
            batch_loss = outputs.loss
            batch_logits = outputs.logits
            if batch_loss is not None:
                loss_sum += batch_loss.item()
            # The predicted class is the index of the largest logit per row.
            predicted = batch_logits.argmax(dim=1)
            # sklearn metrics need CPU-side values.
            targets.extend(inputs["labels"].cpu().numpy())
            predictions.extend(predicted.cpu().numpy())

    if not targets:
        logger.warning("Evaluation dataloader was empty.")
        return 0.0, 0.0

    num_batches = len(dataloader)
    mean_loss = loss_sum / num_batches if num_batches > 0 else 0.0
    return accuracy_score(targets, predictions), mean_loss
def test_model(model, test_dataloader, label_map, device):
    """
    Test the model and log detailed metrics.

    Logs the accuracy, a per-class classification report and a confusion
    matrix. Both reports are computed over every class in ``label_map`` so
    they stay well-formed even if a class is absent from the test set or
    from the predictions.

    Args:
        model: The transformer model
        test_dataloader: DataLoader for test data
        label_map: Mapping from label names to indices
        device: The torch device (e.g., torch.device("cpu"))

    Returns:
        Test accuracy (0.0 if the dataloader yields no examples)
    """
    model.eval()
    true_labels = []
    pred_labels = []
    reverse_label_map = {v: k for k, v in label_map.items()}
    # Fixed class order shared by the report and the confusion matrix.
    label_ids = sorted(reverse_label_map)
    target_names = [reverse_label_map[i] for i in label_ids]

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Testing"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            # Labels are deliberately not passed to the model at inference;
            # they are only used for scoring.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            # Move predictions and labels to CPU for sklearn metrics
            true_labels.extend(labels.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())

    if not true_labels:  # Handle empty dataloader case
        logger.error("Test dataloader was empty. Cannot calculate metrics.")
        return 0.0

    accuracy = accuracy_score(true_labels, pred_labels)

    try:
        # Passing labels= explicitly keeps target_names aligned with the
        # classes even when only one class occurs in the data.
        report = classification_report(
            true_labels,
            pred_labels,
            labels=label_ids,
            target_names=target_names,
            zero_division=0,
        )
    except ValueError as e:
        logger.warning(f"Could not generate classification report: {e}")
        # Log unique values to help debug
        logger.warning(f"Unique true labels: {np.unique(true_labels)}")
        logger.warning(f"Unique predicted labels: {np.unique(pred_labels)}")
        report = "N/A"

    try:
        conf_matrix = confusion_matrix(true_labels, pred_labels, labels=label_ids)
    except ValueError as e:
        logger.warning(f"Could not generate confusion matrix: {e}")
        conf_matrix = "N/A"

    logger.info(f"Test accuracy: {accuracy:.4f}")
    logger.info(f"Classification report:\n{report}")
    logger.info(f"Confusion matrix (Labels: {label_ids}):")
    logger.info(f"{conf_matrix}")
    return accuracy
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model for text classification on CPU")
    parser.add_argument(
        "--data",
        type=str,
        required=True,
        help="Path to JSONL file with 'text' and 'label' fields ('EVERGREEN' or 'TIME-SENSITIVE')",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="answerdotai/ModernBERT-base",
        help="Base model identifier from Hugging Face Hub or local path (default: ModernBERT-base)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save the fine-tuned model and results",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for training and evaluation (default: 4)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs (default: 3)",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Learning rate for AdamW optimizer (default: 2e-5)",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=2048,
        help="Maximum sequence length for tokenizer (default: 2048)",
    )
    parser.add_argument(
        "--test_size",
        type=float,
        default=0.15,
        help="Proportion of data to use for testing (default: 0.15)",
    )
    parser.add_argument(
        "--val_size",
        type=float,
        default=0.15,
        help="Proportion of *training* data to use for validation (default: 0.15)",
    )
    return parser.parse_args()


def _split_with_stratify_fallback(texts, labels, test_size):
    """Split texts/labels, stratifying by label when scikit-learn permits it.

    Stratification fails with ValueError on very small or very imbalanced
    data (e.g. a class with a single member, or fewer held-out examples than
    classes). In that case a plain shuffled split is used instead. A
    ValueError from the fallback split itself propagates to the caller.

    Returns:
        The four lists returned by ``train_test_split``:
        (train_texts, held_out_texts, train_labels, held_out_labels)
    """
    try:
        return train_test_split(
            texts, labels, test_size=test_size, random_state=RANDOM_SEED, stratify=labels
        )
    except ValueError as e:
        logger.warning(f"Stratified split not possible ({e}). Falling back to a non-stratified split.")
        return train_test_split(texts, labels, test_size=test_size, random_state=RANDOM_SEED)


def _split_dataset(texts, labels, test_size, val_size):
    """Partition the data into train, validation and test subsets.

    The test set is carved out first; the validation set is then taken from
    what remains. The validation lists are empty when ``val_size`` is outside
    (0, 1) or fewer than two training examples remain.

    Returns:
        Tuple (train_texts, train_labels, val_texts, val_labels,
        test_texts, test_labels)

    Raises:
        ValueError: If scikit-learn cannot perform a split at all.
    """
    train_texts, test_texts, train_labels, test_labels = _split_with_stratify_fallback(
        texts, labels, test_size
    )

    if len(train_texts) < 2 or not 0.0 < val_size < 1.0:
        logger.warning("Validation split size invalid or training set too small. Skipping validation split.")
        return train_texts, train_labels, [], [], test_texts, test_labels

    # val_size is interpreted as a proportion of the remaining training set.
    val_split = val_size
    if int(len(train_texts) * val_size) < 1:
        logger.warning("Calculated validation set size is less than 1 example. Adjusting validation split.")
        # An integer test_size is an absolute example count, which avoids
        # float rounding in a 1/len proportion.
        val_split = 1
        logger.info("Using 1 example for validation.")

    train_texts, val_texts, train_labels, val_labels = _split_with_stratify_fallback(
        train_texts, train_labels, val_split
    )
    return train_texts, train_labels, val_texts, val_labels, test_texts, test_labels


def main():
    """Run the fine-tuning pipeline end to end.

    Steps: parse arguments, load the JSONL data, split it, load the tokenizer
    and model, train, evaluate on the test set, then save the model, the
    tokenizer and a ``training_results.json`` summary to the output
    directory. Data-loading, splitting and model-loading failures are logged
    and end the run early.
    """
    args = _parse_args()

    # Default output directory is derived from the data file and model names.
    data_filename = os.path.splitext(os.path.basename(args.data))[0]
    # rstrip so a local path with a trailing slash still yields a name.
    model_name = args.model.rstrip("/").split("/")[-1]
    if args.output_dir is None:
        args.output_dir = f"./finetuned_{data_filename}_{model_name}"

    device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    logger.info(f"Loading data from {args.data}")
    try:
        texts, labels = load_jsonl_data(args.data)
    except (FileNotFoundError, ValueError) as e:
        logger.error(f"Failed to load data: {e}")
        return

    logger.info(f"Splitting data: Test size={args.test_size}, Validation size={args.val_size} (of training set)")
    try:
        (
            train_texts,
            train_labels,
            val_texts,
            val_labels,
            test_texts,
            test_labels,
        ) = _split_dataset(texts, labels, args.test_size, args.val_size)
    except ValueError as e:
        logger.error(f"Failed to split data: {e}")
        return

    logger.info(f"Train set size: {len(train_texts)}")
    logger.info(f"Validation set size: {len(val_texts)}")
    logger.info(f"Test set size: {len(test_texts)}")
    if len(train_texts) == 0:
        logger.error("Training set is empty after splitting. Cannot proceed.")
        return

    # Load tokenizer and model
    try:
        logger.info(f"Loading tokenizer for model: {args.model}")
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        logger.info(f"Loading model: {args.model}")
        model = AutoModelForSequenceClassification.from_pretrained(
            args.model, num_labels=2  # Binary classification: EVERGREEN (0), TIME-SENSITIVE (1)
        )
        model.to(device)
        logger.info(f"Model moved to device: {device}")
    except OSError as e:
        logger.error(f"Could not load model or tokenizer '{args.model}'. "
                     f"Ensure it's a valid Hugging Face model identifier or local path. Error: {e}")
        return
    except Exception as e:
        logger.error(f"An unexpected error occurred loading model/tokenizer: {e}")
        return

    train_dataset = JSONLDataset(train_texts, train_labels, tokenizer, max_length=args.max_length)
    val_dataset = JSONLDataset(val_texts, val_labels, tokenizer, max_length=args.max_length) if val_texts else None
    test_dataset = JSONLDataset(test_texts, test_labels, tokenizer, max_length=args.max_length)

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size) if val_dataset else None
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)

    if not val_dataloader:
        logger.warning("No validation set/dataloader created. Training will proceed without validation pauses or best model saving based on validation.")

    logger.info("Starting model training...")
    start_time = time.time()
    model = train_model(
        model,
        train_dataloader,
        val_dataloader,
        device,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
    )
    training_time = time.time() - start_time
    logger.info(f"Training completed in {training_time:.2f} seconds")

    logger.info("Evaluating model on the test set...")
    if len(test_dataloader) > 0:
        test_accuracy = test_model(model, test_dataloader, train_dataset.label_map, device)
    else:
        logger.warning("Test dataloader is empty. Skipping testing.")
        test_accuracy = 0.0

    # Saving is best-effort: a failure is logged and the results summary is
    # still written.
    logger.info(f"Saving final model and tokenizer to {args.output_dir}")
    try:
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
    except Exception as e:
        logger.error(f"Error saving model/tokenizer: {e}")

    # Save configuration and results
    results = {
        "model_base": args.model,
        "final_test_accuracy": float(test_accuracy),
        "training_time_seconds": round(training_time, 2),
        "max_seq_length": args.max_length,
        "batch_size": args.batch_size,
        "epochs_trained": args.epochs,
        "learning_rate": args.learning_rate,
        "random_seed": RANDOM_SEED,
        "num_train_examples": len(train_texts),
        "num_val_examples": len(val_texts),
        "num_test_examples": len(test_texts),
        "data_source": args.data,
        "output_directory": args.output_dir,
    }
    results_path = os.path.join(args.output_dir, "training_results.json")
    try:
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {results_path}")
    except Exception as e:
        logger.error(f"Error saving results JSON: {e}")

    logger.info(f"Script finished. Final test accuracy: {test_accuracy:.4f}")


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment