```py
import torch
from trl import SFTTrainer


# Custom trainer that applies exponential decay to response tokens
class WeightedSFTTrainer(SFTTrainer):
    def __init__(self, *args, early_weight=10.0, decay_rate=0.1, min_weight=0.0, skip_special_tokens=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_weight = early_weight
        self.decay_rate = decay_rate
        self.min_weight = min_weight
        self.skip_special_tokens = skip_special_tokens

        # Get special token IDs to skip (template tokens)
        self.special_token_ids = set()
        if skip_special_tokens and hasattr(self, 'tokenizer'):
            tokenizer = self.tokenizer
            # Common special tokens to skip
            special_tokens = [
                '<|im_start|>', '<|im_end|>', '<|endoftext|>',
                '<s>', '</s>', '[INST]', '[/INST]',
                '▁', '###', '<|assistant|>', '<|user|>',
                '<<SYS>>', '<</SYS>>', '[', ']', '<think>', '</think>'
            ]
            for token in special_tokens:
                token_ids = tokenizer.encode(token, add_special_tokens=False)
                self.special_token_ids.update(token_ids)
            # Also add any additional special tokens from the tokenizer
            if hasattr(tokenizer, 'all_special_ids'):
                self.special_token_ids.update(tokenizer.all_special_ids)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.get("labels")
        if labels is None:
            return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)

        # Get logits
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Create weight tensor
        weights = torch.ones_like(labels, dtype=torch.float, device=labels.device)

        # For each sequence in the batch
        for seq_idx in range(labels.shape[0]):
            seq_labels = labels[seq_idx]
            seq_input_ids = inputs["input_ids"][seq_idx]
            in_response = False
            response_pos = 0
            found_content_token = False

            for pos_idx in range(len(seq_labels)):
                if seq_labels[pos_idx] == -100:
                    # Masked token (prompt or padding)
                    in_response = False
                    response_pos = 0
                    found_content_token = False
                    weights[seq_idx, pos_idx] = 0.0
                else:
                    # Response token
                    if not in_response:
                        in_response = True
                        response_pos = 0
                        found_content_token = False

                    # Check if this is a special/template token we should skip
                    # (int() so the membership test against the set of Python ints works)
                    token_id = int(seq_input_ids[pos_idx + 1]) if pos_idx + 1 < len(seq_input_ids) else int(seq_input_ids[pos_idx])
                    is_special = token_id in self.special_token_ids

                    if is_special and not found_content_token:
                        # Skip weighting template tokens at the beginning
                        weights[seq_idx, pos_idx] = 1.0  # Normal weight for template tokens
                    else:
                        found_content_token = True
                        # Exponential decay: weight = early_weight * exp(-decay_rate * position)
                        decayed_weight = self.early_weight * torch.exp(torch.tensor(-self.decay_rate * response_pos, device=labels.device))
                        weight = torch.max(decayed_weight, torch.tensor(self.min_weight, device=labels.device))

                        weights[seq_idx, pos_idx] = weight.item()
                        response_pos += 1

        # Compute weighted cross-entropy loss
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_weights = weights[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        flat_weights = shift_weights.view(-1)

        # Compute per-token losses
        losses = loss_fct(flat_logits, flat_labels)

        # Apply weights and mask
        mask = (flat_labels != -100).float()
        weighted_losses = losses * flat_weights * mask

        # Average over non-masked tokens
        total_loss = weighted_losses.sum()
        total_weight = (flat_weights * mask).sum()
        loss = total_loss / (total_weight + 1e-8)

        return (loss, outputs) if return_outputs else loss
```
This is what GPT-5 generated for me. It has some issues: it uses 5-6x more VRAM and takes about twice as long to train, and the analysis below (from GPT-5 and Opus) reportedly explains why. To be honest, I don't fully understand how the trainer works, but I think it expects the user input to already be masked out in the labels, since the loss is only meant to be computed on response tokens.
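
For reference, here is a minimal sketch of the masking convention the trainer seems to rely on: prompt tokens get the label -100 (which CrossEntropyLoss ignores) and response tokens keep their real ids; trl's DataCollatorForCompletionOnlyLM builds labels like this automatically from a response template. The tokenizer, prompt, and response below are just placeholders:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer works here

prompt_ids = tokenizer.encode("User: What is 2+2?\nAssistant:", add_special_tokens=False)
response_ids = tokenizer.encode(" 4", add_special_tokens=False)

input_ids = torch.tensor(prompt_ids + response_ids)
labels = input_ids.clone()
labels[:len(prompt_ids)] = -100  # mask the prompt so only the response contributes to the loss
```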

-------

## The 4-5x VRAM explosion

**1) KV Cache left enabled (biggest culprit)**
```python
# Your model likely has this enabled by default
model.config.use_cache = True  # YOU NEED TO SET THIS TO FALSE
```
- During training, this caches key/value tensors for every attention layer for "generation efficiency"
- With your settings: 12k context × batch 8 × ~32 layers × hidden dims = **~10-15 GB extra**
- Stock SFTTrainer automatically disables this. You didn't.
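
If the model is constructed manually before it is handed to the trainer, a minimal sketch of the suggested setting (the model name is a placeholder; gradient checkpointing is optional but, if used, also requires the cache to be off):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("your-base-model")  # placeholder name
model.config.use_cache = False           # KV cache is only useful for generation, not training
model.gradient_checkpointing_enable()    # optional; incompatible with use_cache=True anyway
```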

**2) Double loss computation**
```python
# What happens in your compute_loss:
outputs = model(**inputs)   # ← Model sees 'labels', computes its own CE loss internally
logits = outputs.logits     # ← You then compute ANOTHER loss from the same logits
```
- The model builds a full forward graph for its internal loss, keeps those activations
- You build a second graph for your custom loss
- **Result**: ~2x memory for loss computation paths

**3) Memory fragmentation from thousands of tiny allocations**
```python
# This line runs thousands of times per batch:
torch.tensor(-self.decay_rate * float(response_pos), device=labels.device)
# And this:
weight_val.item()  # Moves scalar from GPU to CPU
```
- Each ```torch.tensor(..., device=labels.device)``` allocates a tiny GPU block
- Each ```.item()``` forces GPU→CPU transfer
- PyTorch's caching allocator grows to satisfy these, never shrinks during training
- **Result**: "Reserved" memory balloons due to fragmentation
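
One way to avoid those thousands of tiny allocations is to precompute the decay curve once as a lookup table and just index into it inside the loop; a sketch, with argument names mirroring the trainer's:

```python
import torch

def make_decay_table(max_len, early_weight=10.0, decay_rate=0.1, min_weight=0.0, device="cpu"):
    # position p -> early_weight * exp(-decay_rate * p), floored at min_weight
    positions = torch.arange(max_len, dtype=torch.float, device=device)
    return (early_weight * torch.exp(-decay_rate * positions)).clamp_min(min_weight)

decay_table = make_decay_table(12_000)  # one allocation per run, not one per token
# inside the loop: weights[seq_idx, pos_idx] = decay_table[response_pos]  (no new tensors, no .item())
```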

**4) Autograd graph pollution**
```python
# Inside your loss computation (inside autograd context):
decayed_weight = self.early_weight * torch.exp(torch.tensor(-self.decay_rate * response_pos))
```
- Creating tensors with gradients inside the loss computation
- These get added to the autograd graph unnecessarily
- More tensors kept alive → more memory

## The 50% GPU utilization (sawtooth pattern)

**The smoking gun: ```.item()``` calls in nested loops**
```python
for seq_idx in range(batch_size):            # 8 iterations
    for pos_idx in range(seq_len):           # 12,000 iterations
        if seq_labels[pos_idx].item() == -100:   # ← GPU→CPU sync
            # ...
        token_id = int(seq_input_ids[next_pos].item())          # ← Another GPU→CPU sync
        weight_val = torch.max(decayed_weight, min_w)
        weights[seq_idx, pos_idx] = float(weight_val.item())    # ← Another GPU→CPU sync
```

**What happens:**
- **8 × 12,000 = 96,000** ```.item()``` calls per batch
- Each ```.item()``` forces the GPU to:
  1. Stop current work
  2. Copy one scalar from GPU memory to CPU
  3. Wait for CPU to acknowledge
  4. Resume GPU work
- **Timeline per batch:**
```
Forward pass:   GPU 100% ████████████████
Your loss:      GPU   0% (CPU doing .item() loops)
Backward pass:  GPU 100% ████████████████
Your loss:      GPU   0% (CPU doing more loops)
```

**Why stock SFTTrainer doesn't have this:**
```python
# Stock SFTTrainer compute_loss (simplified):
def compute_loss(self, model, inputs, return_outputs=False):
    outputs = model(**inputs)  # Model computes loss internally, returns scalar
    return outputs.loss        # No Python loops, no .item() calls, no custom processing
```
- Zero Python-side per-token processing
- Loss computed entirely on GPU with fused kernels
- GPU utilization: steady 100% during forward+backward

## How to fix it

**1) Disable KV cache**
```python
model.config.use_cache = False  # Add this before training
```

**2) Avoid double loss computation**
```python
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    labels = inputs.pop("labels")  # Remove labels so model doesn't compute its own loss
    outputs = model(**inputs)      # Forward pass only
    logits = outputs.logits
    # ... your custom loss computation
    inputs["labels"] = labels      # Put back for logging if needed
```
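
Putting 1) and 2) together, here is a hedged sketch of what the whole method could look like with a single forward pass and a fully vectorized weighted cross-entropy, so there are no Python loops, no ```.item()``` calls, and only one loss graph. It assumes the per-token ```weights``` tensor already comes with the batch (e.g. precomputed during preprocessing, as in option A of 3) below), which is a change from how the original code builds it:

```python
import torch
import torch.nn.functional as F

# Sketch of WeightedSFTTrainer.compute_loss (method shown standalone for brevity)
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    labels = inputs.pop("labels")    # model no longer computes its own internal loss
    weights = inputs.pop("weights")  # assumed: per-token weights built during preprocessing
    outputs = model(**inputs)        # single forward pass
    logits = outputs.logits

    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_weights = weights[..., 1:].contiguous().float()

    losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
        ignore_index=-100,
    )
    flat_weights = shift_weights.view(-1)
    mask = (shift_labels.view(-1) != -100).float()
    weighted = losses * flat_weights * mask
    loss = weighted.sum() / ((flat_weights * mask).sum() + 1e-8)

    inputs["labels"] = labels  # restore for any downstream logging
    return (loss, outputs) if return_outputs else loss
```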

**3) Move weight computation out of the hot path**
```python
# Option A: Precompute during dataset preprocessing
# Store weights alongside input_ids/labels in your cached dataset

# Option B: Vectorized computation (no Python loops)
@torch.no_grad()  # Don't pollute autograd graph
def build_weights_vectorized(labels, input_ids, ...):
    # Compute entirely on CPU or with vectorized GPU ops
    # No .item() calls, no tiny tensor creation
```
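
For option B, here is a hedged sketch of a fully vectorized builder. To stay short it drops the special/template-token skipping from the original (so it only needs ```labels```), it derives each token's position within its response run without any Python loops, and the result can either be built per batch or precomputed once and cached alongside ```input_ids```/```labels``` as in option A:

```python
import torch

@torch.no_grad()
def build_weights_vectorized(labels, early_weight=10.0, decay_rate=0.1, min_weight=0.0):
    """labels: (batch, seq_len) with -100 on prompt/padding positions.
    Returns per-token weights: early_weight * exp(-decay_rate * pos_in_response), floored at min_weight."""
    resp = labels != -100                   # True on response tokens
    csum = resp.cumsum(dim=1)               # running count of response tokens

    prev = torch.zeros_like(resp)
    prev[:, 1:] = resp[:, :-1]
    run_start = resp & ~prev                # first token of each contiguous response run

    # response tokens seen before the current run, carried forward across the run
    start_offset = torch.where(run_start, csum - 1, torch.zeros_like(csum))
    start_offset = torch.cummax(start_offset, dim=1).values
    pos_in_run = (csum - 1 - start_offset).clamp_min(0)

    weights = (early_weight * torch.exp(-decay_rate * pos_in_run.float())).clamp_min(min_weight)
    return torch.where(resp, weights, torch.zeros_like(weights))
```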

**4) If you must keep the loop approach, do it on CPU:**
```python
@torch.no_grad()
def build_weights_cpu(labels, input_ids, special_token_ids, early_weight, decay_rate, min_weight):
    labels_cpu = labels.cpu()        # Move once
    input_ids_cpu = input_ids.cpu()  # Move once
    # ... do all your Python loops on CPU tensors (no .item() needed)
    # ... build weights_cpu as regular Python list/numpy
    return torch.tensor(weights_cpu, device=labels.device)  # Move back once
```

## The bottom line
- **VRAM explosion**: Mostly KV cache + double loss + memory fragmentation
- **GPU starvation**: 96,000 GPU→CPU syncs per batch from ```.item()``` in nested loops
- **Fix**: Disable cache, avoid double loss, move weight computation out of the training loop

With those changes, you should see VRAM drop to ~1.2x stock SFTTrainer and steady GPU utilization.