
load_data.py

Sep 29th, 2024
# %%writefile load_data.py
from transformers import BertTokenizer, BertForSequenceClassification, BatchEncoding, GPT2Tokenizer
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from model_builder import ReviewRatingClassifier
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch.nn as nn
import pandas as pd
import numpy as np
import datasets
import torch
import os

NUM_WORKERS = os.cpu_count()


@dataclass(frozen=True)
class CreateDataset(torch.utils.data.Dataset):
    reviews: np.ndarray
    labels: np.ndarray
    tokenizer: BertTokenizer | GPT2Tokenizer
    max_len: int

    def __len__(self) -> int:
        return len(self.reviews)

    def __getitem__(self, item: int) -> Dict[str, torch.Tensor]:
        review = str(self.reviews[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            #             'review_text': review,
            # encode_plus with return_tensors='pt' yields tensors of shape (1, max_len);
            # flatten them so the default DataLoader collation produces (batch_size, max_len)
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

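# Example usage (sketch; X_train / y_train are illustrative stand-ins for your raw arrays):
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   train_ds = CreateDataset(reviews=X_train, labels=y_train, tokenizer=tokenizer, max_len=128)
#   train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=NUM_WORKERS)
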
def create_datasets(x: List[str], y: List[int], tokenizer: BertTokenizer | GPT2Tokenizer) -> datasets.Dataset:
    # Assuming you have X_train and y_train as lists or numpy arrays
    data = {'review': x, 'label': y}
    df_train = pd.DataFrame(data)

    # Convert to Hugging Face datasets.Dataset
    _dataset = datasets.Dataset.from_pandas(df_train)

    # Repeat for validation data
    # data_val = {'review': x_val, 'label': y_val}
    # df_val = pd.DataFrame(data_val)
    # val_dataset = datasets.Dataset.from_pandas(df_val)

    # Tokenize the dataset using map
    def tokenize_function(example) -> BatchEncoding:
        return tokenizer(example['review'], padding="max_length", truncation=True, max_length=128)

    _dataset = _dataset.map(tokenize_function, batched=True)
    # val_dataset = val_dataset.map(tokenize_function, batched=True)

    # Remove columns that are not tensors (e.g., 'review', since the model only consumes tensor inputs)
    _dataset = _dataset.remove_columns(['review'])
    # val_dataset = val_dataset.remove_columns(['review'])

    # Set format for PyTorch tensors
    _dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    # val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    return _dataset

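# Example usage (sketch; X_train / y_train are illustrative):
#   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#   tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
#   train_dataset = create_datasets(X_train, y_train, tokenizer)
#   train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
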


# Custom cyclic learning-rate scheduler (triangular policy)
class CyclicLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 base_lr: float,
                 max_lr: float,
                 step_size_up: int = 2000,
                 step_size_down: Optional[int] = None,
                 mode: str = 'triangular',
                 gamma: float = 1.,
                 scale_fn=None,
                 scale_mode: str = 'cycle',
                 cycle_momentum: bool = True,
                 base_momentum: float = 0.8,
                 max_momentum: float = 0.9,
                 last_epoch: int = -1):
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size_up = step_size_up
        self.step_size_down = step_size_down if step_size_down is not None else step_size_up
        self.total_size = self.step_size_up + self.step_size_down
        self.mode = mode
        self.gamma = gamma
        self.scale_fn = scale_fn
        self.scale_mode = scale_mode
        self.cycle_momentum = cycle_momentum
        self.base_momentums = [base_momentum] * len(optimizer.param_groups)
        self.max_momentums = [max_momentum] * len(optimizer.param_groups)
        self.max_lrs = [max_lr] * len(optimizer.param_groups)
        # The parent class wires up the optimizer and performs the initial step
        super().__init__(optimizer, last_epoch)
        # Use the explicit base_lr for every param group
        self.base_lrs = [base_lr] * len(optimizer.param_groups)

    def get_lr(self) -> List[float]:
        # Position within the current cycle: rise during the first half, fall during the second
        cycle = np.floor(1 + self.last_epoch / self.total_size)
        x = 1 + self.last_epoch / self.total_size - cycle
        if x <= 0.5:
            scale_factor = x * 2
        else:
            scale_factor = (1 - x) * 2

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            lr = base_lr + base_height
            lrs.append(lr)

        return lrs

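# Example usage (sketch): step the cyclic scheduler once per training batch so the
# learning rate sweeps between base_lr and max_lr over each cycle.
#   scheduler = CyclicLRScheduler(optimizer, base_lr=1e-4, max_lr=1e-2, step_size_up=2000)
#   for batch in train_dtl:
#       ...  # forward / backward / optimizer.step()
#       scheduler.step()
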

def check_dataset_shapes(dataset: datasets.Dataset, expected_len: int = 128) -> None:
    # Sanity check: every tokenized example should share the same padded/truncated length
    for idx in range(len(dataset)):
        item = dataset[idx]
        assert item['input_ids'].shape[-1] == expected_len, f"\nError: input_ids at index {idx} is not {expected_len} in length"
        assert item['attention_mask'].shape[-1] == expected_len, f"\nError: attention_mask at index {idx} is not {expected_len} in length"


def load_data_objs(
        batch_size: int,
        rank: int,
        world_size: int,
        epochs: int,
        x_train_path: str,
        y_train_path: str,
        x_val_path: str,
        y_val_path: str,
        gpu: bool,
        gpu_id: int,
        learning_rate: float,
        num_workers: int,
        lr_scheduler: Optional[str] = None,
) -> tuple[DataLoader, DataLoader, nn.Module, nn.CrossEntropyLoss, torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    def load_tensor(path: str, name: str) -> np.ndarray:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{name} file not found: {path}")
        try:
            return np.load(path, allow_pickle=True)
        except Exception as er:
            raise RuntimeError(f"Error loading {name} from {path}: {str(er)}")

    try:
        xtrain = load_tensor(x_train_path, "X_train.npy")
        ytrain = load_tensor(y_train_path, "y_train.npy")
        xval = load_tensor(x_val_path, "X_val.npy")
        yval = load_tensor(y_val_path, "y_val.npy")
    except Exception as e:
        print(f"Error loading data: {str(e)}")
        raise

    # Ensure that the number of reviews matches the number of labels
    assert len(xtrain) == len(ytrain), "Mismatch between X_train and y_train lengths"
    assert len(xval) == len(yval), "Mismatch between X_val and y_val lengths"

    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    # Compute class weights
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(ytrain), y=ytrain)
    # model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
    model: ReviewRatingClassifier = ReviewRatingClassifier(num_classes=5, unfreeze_layers=10)
    # Note: transformers.AdamW is deprecated; use the PyTorch implementation instead
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # train_dts: torch.utils.data.Dataset = CreateDataset(xtrain, ytrain, tokenizer, 128)
    # val_dts: torch.utils.data.Dataset = CreateDataset(xval, yval, tokenizer, 128)
    train_dts = create_datasets(xtrain, ytrain, tokenizer)
    val_dts = create_datasets(xval, yval, tokenizer)
    #     check_dataset_shapes(train_dts)

    if gpu:
        criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).float().to(gpu_id))
        # criterion = nn.CrossEntropyLoss()
        train_dtl = DataLoader(train_dts, batch_size=batch_size, shuffle=False, pin_memory=True,
                               sampler=DistributedSampler(
                                   train_dts, num_replicas=world_size, rank=rank), num_workers=num_workers, )
        val_dtl = DataLoader(val_dts, batch_size=1, shuffle=False, pin_memory=True, sampler=DistributedSampler(
            val_dts, num_replicas=world_size, rank=rank), num_workers=num_workers, )
    else:
        criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).float())
        train_dtl = DataLoader(train_dts, batch_size=batch_size,
                               shuffle=False, pin_memory=True, num_workers=num_workers, )
        val_dtl = DataLoader(val_dts, batch_size=batch_size,
                             shuffle=False, pin_memory=True, num_workers=num_workers, )

    scheduler = None
    if lr_scheduler:
        # Build schedulers lazily so only the requested one touches the optimizer's state
        LR_SCHEDULER = {
            "cyclic_lr": lambda: CyclicLRScheduler(optimizer, base_lr=0.0001, max_lr=0.01, step_size_up=2000,
                                                   mode='triangular'),
            # requires a metric to step
            "reduce_lr": lambda: torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                                            patience=2),
            "one_cycle_lr": lambda: torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, epochs=epochs,
                                                                        steps_per_epoch=len(train_dtl),
                                                                        anneal_strategy='cos'),
            "cosine": lambda: torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
        }
        if lr_scheduler in LR_SCHEDULER:
            scheduler = LR_SCHEDULER[lr_scheduler]()
        else:
            raise ValueError(f"Invalid lr_scheduler value: {lr_scheduler}. "
                             f"Valid options are: {list(LR_SCHEDULER.keys())}")

    return train_dtl, val_dtl, model, criterion, optimizer, scheduler
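
# Minimal smoke-test sketch (assumes hypothetical .npy paths and a single-process CPU run;
# adjust the paths and hyperparameters to your own setup).
if __name__ == "__main__":
    train_dtl, val_dtl, model, criterion, optimizer, scheduler = load_data_objs(
        batch_size=16,
        rank=0,
        world_size=1,
        epochs=3,
        x_train_path="X_train.npy",   # hypothetical path
        y_train_path="y_train.npy",   # hypothetical path
        x_val_path="X_val.npy",       # hypothetical path
        y_val_path="y_val.npy",       # hypothetical path
        gpu=False,
        gpu_id=0,
        learning_rate=2e-5,
        num_workers=NUM_WORKERS,
        lr_scheduler=None,
    )
    print(f"train batches: {len(train_dtl)}, val batches: {len(val_dtl)}")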