This is the version of the class that Opus refers to in the notes. The logic is the same, but it writes out the types explicitly:

```py
from typing import Any, Dict, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
from trl import SFTTrainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import ModelOutput


class WeightedSFTTrainer(SFTTrainer):
    """
    SFTTrainer with exponential per-token weighting applied to response tokens.
    Special/template tokens at the start of the response can be given normal weight.
    """

    early_weight: float
    decay_rate: float
    min_weight: float
    skip_special_tokens: bool
    special_token_ids: Set[int]
    tokenizer: Optional[PreTrainedTokenizerBase]

    def __init__(
        self,
        *args: Any,
        early_weight: float = 10.0,
        decay_rate: float = 0.1,
        min_weight: float = 0.1,
        skip_special_tokens: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.early_weight = early_weight
        self.decay_rate = decay_rate
        self.min_weight = min_weight
        self.skip_special_tokens = skip_special_tokens

        # Get special token IDs to skip (template tokens)
        self.special_token_ids = set()
        if skip_special_tokens and hasattr(self, "tokenizer"):
            tokenizer = self.tokenizer
            if tokenizer is not None:
                special_tokens = [
                    "<|im_start|>",
                    "<|im_end|>",
                    "<|endoftext|>",
                    "<s>",
                    "</s>",
                    "[INST]",
                    "[/INST]",
                    "▁",
                    "###",
                    "<|assistant|>",
                    "<|user|>",
                    "<<SYS>>",
                    "<</SYS>>",
                    "[",
                    "]",
                    "<think>",
                    "</think>",
                ]
                for token in special_tokens:
                    token_ids = tokenizer.encode(token, add_special_tokens=False)
                    self.special_token_ids.update(token_ids)
                if hasattr(tokenizer, "all_special_ids") and tokenizer.all_special_ids is not None:
                    self.special_token_ids.update(tokenizer.all_special_ids)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Union[ModelOutput, Dict[str, Any], Any]]]:
        labels: Optional[torch.Tensor] = inputs.get("labels")
        if labels is None:
            return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)

        # Forward pass
        outputs: Union[ModelOutput, Dict[str, Any], Any] = model(**inputs)

        # Get logits regardless of output structure
        logits: torch.Tensor
        if isinstance(outputs, dict):
            logits = outputs.get("logits")  # type: ignore[assignment]
        else:
            logits = getattr(outputs, "logits")

        # Create weight tensor
        weights: torch.Tensor = torch.ones_like(labels, dtype=torch.float, device=labels.device)

        # For each sequence in the batch
        batch_size: int = labels.shape[0]
        seq_len: int = labels.shape[1]
        input_ids: torch.Tensor = inputs["input_ids"]

        for seq_idx in range(batch_size):
            seq_labels: torch.Tensor = labels[seq_idx]
            seq_input_ids: torch.Tensor = input_ids[seq_idx]
            in_response: bool = False
            response_pos: int = 0
            found_content_token: bool = False

            for pos_idx in range(seq_len):
                if seq_labels[pos_idx].item() == -100:
                    # Masked token (prompt or padding)
                    in_response = False
                    response_pos = 0
                    found_content_token = False
                    weights[seq_idx, pos_idx] = 0.0
                else:
                    # Response token
                    if not in_response:
                        in_response = True
                        response_pos = 0
                        found_content_token = False

                    # Check if this is a special/template token we should skip
                    # Lookahead by one position for labels->next-token prediction alignment
                    next_pos: int = pos_idx + 1 if (pos_idx + 1) < seq_len else pos_idx
                    token_id: int = int(seq_input_ids[next_pos].item())
                    is_special: bool = token_id in self.special_token_ids

                    if is_special and not found_content_token:
                        # Normal weight for leading template tokens
                        weights[seq_idx, pos_idx] = 1.0
                    else:
                        found_content_token = True
                        # Exponential decay: weight = early_weight * exp(-decay_rate * position), floored by min_weight
                        decayed_weight: torch.Tensor = self.early_weight * torch.exp(
                            torch.tensor(-self.decay_rate * float(response_pos), device=labels.device)
                        )
                        min_w: torch.Tensor = torch.tensor(self.min_weight, device=labels.device)
                        weight_val: torch.Tensor = torch.max(decayed_weight, min_w)
                        weights[seq_idx, pos_idx] = float(weight_val.item())
                    response_pos += 1

        # Compute weighted cross-entropy loss
        shift_logits: torch.Tensor = logits[..., :-1, :].contiguous()
        shift_labels: torch.Tensor = labels[..., 1:].contiguous()
        shift_weights: torch.Tensor = weights[..., 1:].contiguous()

        loss_fct: nn.CrossEntropyLoss = nn.CrossEntropyLoss(reduction="none")
        flat_logits: torch.Tensor = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels: torch.Tensor = shift_labels.view(-1)
        flat_weights: torch.Tensor = shift_weights.view(-1)

        # Per-token losses
        losses: torch.Tensor = loss_fct(flat_logits, flat_labels)

        # Apply weights and mask
        mask: torch.Tensor = (flat_labels != -100).float()
        weighted_losses: torch.Tensor = losses * flat_weights * mask

        # Average over non-masked tokens
        total_loss: torch.Tensor = weighted_losses.sum()
        total_weight: torch.Tensor = (flat_weights * mask).sum()
        loss: torch.Tensor = total_loss / (total_weight + torch.tensor(1e-8, device=total_weight.device))

        return (loss, outputs) if return_outputs else loss
```
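
For context, a minimal usage sketch follows. The model name, dataset, and SFTConfig values are illustrative placeholders rather than anything from the notes, and depending on the TRL version the tokenizer is passed as tokenizer= or processing_class=:

```py
# Hypothetical usage sketch: model, dataset, and config values are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
train_dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

trainer = WeightedSFTTrainer(
    model=model,
    args=SFTConfig(output_dir="weighted-sft-demo", max_steps=100),
    train_dataset=train_dataset,
    processing_class=tokenizer,  # older TRL versions take tokenizer= instead
    # With these defaults the first response token is weighted 10x,
    # the weight falls to ~1.0 around token 23, and the 0.1 floor
    # takes over after roughly 46 tokens (10 * exp(-0.1 * 46) ~= 0.1).
    early_weight=10.0,
    decay_rate=0.1,
    min_weight=0.1,
    skip_special_tokens=True,
)
trainer.train()
```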