This is the version of the class that Opus refers to in the notes. The logic is the same but it explicitly writes out types:
```py
from typing import Any, Dict, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import ModelOutput
from trl import SFTTrainer


class WeightedSFTTrainer(SFTTrainer):
- """
- SFTTrainer with exponential per-token weighting applied to response tokens.
- Special/template tokens at the start of the response can be given normal weight.
- """
- early_weight: float
- decay_rate: float
- min_weight: float
- skip_special_tokens: bool
- special_token_ids: Set[int]
- tokenizer: Optional[PreTrainedTokenizerBase]
- def __init__(
- self,
- *args: Any,
- early_weight: float = 10.0,
- decay_rate: float = 0.1,
- min_weight: float = 0.1,
- skip_special_tokens: bool = True,
- **kwargs: Any,
- ) -> None:
- super().__init__(*args, **kwargs)
- self.early_weight = early_weight
- self.decay_rate = decay_rate
- self.min_weight = min_weight
- self.skip_special_tokens = skip_special_tokens
        # Get special token IDs to skip (template tokens)
        self.special_token_ids = set()
        if skip_special_tokens and hasattr(self, "tokenizer"):
            tokenizer = self.tokenizer
            if tokenizer is not None:
                special_tokens = [
                    "<|im_start|>",
                    "<|im_end|>",
                    "<|endoftext|>",
                    "<s>",
                    "</s>",
                    "[INST]",
                    "[/INST]",
                    "▁",
                    "###",
                    "<|assistant|>",
                    "<|user|>",
                    "<<SYS>>",
                    "<</SYS>>",
                    "[",
                    "]",
                    "<think>",
                    "</think>",
                ]
                for token in special_tokens:
                    token_ids = tokenizer.encode(token, add_special_tokens=False)
                    self.special_token_ids.update(token_ids)
                if hasattr(tokenizer, "all_special_ids") and tokenizer.all_special_ids is not None:
                    self.special_token_ids.update(tokenizer.all_special_ids)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Union[ModelOutput, Dict[str, Any], Any]]]:
        labels: Optional[torch.Tensor] = inputs.get("labels")
        if labels is None:
            return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)

        # Forward pass
        outputs: Union[ModelOutput, Dict[str, Any], Any] = model(**inputs)

        # Get logits regardless of output structure
        logits: torch.Tensor
        if isinstance(outputs, dict):
            logits = outputs.get("logits")  # type: ignore[assignment]
        else:
            logits = getattr(outputs, "logits")

        # Create weight tensor
        weights: torch.Tensor = torch.ones_like(labels, dtype=torch.float, device=labels.device)

        # For each sequence in the batch
        batch_size: int = labels.shape[0]
        seq_len: int = labels.shape[1]
        input_ids: torch.Tensor = inputs["input_ids"]
        for seq_idx in range(batch_size):
            seq_labels: torch.Tensor = labels[seq_idx]
            seq_input_ids: torch.Tensor = input_ids[seq_idx]
            in_response: bool = False
            response_pos: int = 0
            found_content_token: bool = False
            for pos_idx in range(seq_len):
                if seq_labels[pos_idx].item() == -100:
                    # Masked token (prompt or padding)
                    in_response = False
                    response_pos = 0
                    found_content_token = False
                    weights[seq_idx, pos_idx] = 0.0
                else:
                    # Response token
                    if not in_response:
                        in_response = True
                        response_pos = 0
                        found_content_token = False
                    # Check whether this is a special/template token we should skip.
                    # `weights` is shifted together with `labels` below, so the token
                    # weighted at this position is the one at pos_idx itself; no
                    # lookahead is needed here.
                    token_id: int = int(seq_input_ids[pos_idx].item())
                    is_special: bool = token_id in self.special_token_ids
                    if is_special and not found_content_token:
                        # Normal weight for leading template tokens
                        weights[seq_idx, pos_idx] = 1.0
                    else:
                        found_content_token = True
                        # Exponential decay: weight = early_weight * exp(-decay_rate * position),
                        # floored by min_weight
                        decayed_weight: torch.Tensor = self.early_weight * torch.exp(
                            torch.tensor(-self.decay_rate * float(response_pos), device=labels.device)
                        )
                        min_w: torch.Tensor = torch.tensor(self.min_weight, device=labels.device)
                        weight_val: torch.Tensor = torch.max(decayed_weight, min_w)
                        weights[seq_idx, pos_idx] = float(weight_val.item())
                        response_pos += 1

        # Compute weighted cross-entropy loss
        shift_logits: torch.Tensor = logits[..., :-1, :].contiguous()
        shift_labels: torch.Tensor = labels[..., 1:].contiguous()
        shift_weights: torch.Tensor = weights[..., 1:].contiguous()

        loss_fct: nn.CrossEntropyLoss = nn.CrossEntropyLoss(reduction="none")
        flat_logits: torch.Tensor = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels: torch.Tensor = shift_labels.view(-1)
        flat_weights: torch.Tensor = shift_weights.view(-1)

        # Per-token losses
        losses: torch.Tensor = loss_fct(flat_logits, flat_labels)

        # Apply weights and mask
        mask: torch.Tensor = (flat_labels != -100).float()
        weighted_losses: torch.Tensor = losses * flat_weights * mask

        # Weighted average over non-masked tokens (normalized by total weight)
        total_loss: torch.Tensor = weighted_losses.sum()
        total_weight: torch.Tensor = (flat_weights * mask).sum()
        loss: torch.Tensor = total_loss / (total_weight + torch.tensor(1e-8, device=total_weight.device))

        return (loss, outputs) if return_outputs else loss
```
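For a sense of what the defaults imply, here is a quick standalone check (not part of the original paste) of the weight schedule: the first response token is weighted 10.0, the weight falls by a factor of e every 10 tokens, and the 0.1 floor kicks in around position 47.

```py
# Sketch of the per-token weight schedule, assuming the default
# hyperparameters above (early_weight=10.0, decay_rate=0.1, min_weight=0.1).
import math

def token_weight(pos: int, early: float = 10.0, decay: float = 0.1, floor: float = 0.1) -> float:
    # Mirrors the trainer: early * exp(-decay * pos), floored at `floor`.
    return max(early * math.exp(-decay * pos), floor)

for pos in [0, 10, 20, 30, 40, 47]:
    print(pos, round(token_weight(pos), 3))
# 0 10.0 | 10 3.679 | 20 1.353 | 30 0.498 | 40 0.183 | 47 0.1 (floor)
```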
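And a minimal usage sketch, also not from the original paste: it assumes a recent trl where SFTTrainer accepts a model name, an SFTConfig, and a Hugging Face Dataset with a "text" column; the checkpoint name and toy dataset below are placeholders.

```py
# Hypothetical usage of WeightedSFTTrainer; the model checkpoint, output
# directory, and toy dataset are illustrative, not from the paste.
from datasets import Dataset
from trl import SFTConfig

train_dataset = Dataset.from_dict(
    {"text": ["### Question: What is 2 + 2?\n### Answer: 4"]}
)

trainer = WeightedSFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # any causal-LM checkpoint
    args=SFTConfig(output_dir="./weighted-sft", max_steps=10),
    train_dataset=train_dataset,
    early_weight=10.0,  # weight on the first response token
    decay_rate=0.1,     # per-token exponential decay rate
    min_weight=0.1,     # floor so late tokens still contribute
)
trainer.train()
```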