```py
import torch
from trl import SFTTrainer


# Custom trainer that applies exponential decay to response-token loss weights
class WeightedSFTTrainer(SFTTrainer):
    def __init__(self, *args, early_weight=10.0, decay_rate=0.1, min_weight=0.0, skip_special_tokens=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_weight = early_weight
        self.decay_rate = decay_rate
        self.min_weight = min_weight
        self.skip_special_tokens = skip_special_tokens

        # Collect special/template token IDs so they keep normal weight
        self.special_token_ids = set()
        if skip_special_tokens and hasattr(self, 'tokenizer'):
            tokenizer = self.tokenizer
            # Common chat-template tokens to skip
            special_tokens = [
                '<|im_start|>', '<|im_end|>', '<|endoftext|>',
                '<s>', '</s>', '[INST]', '[/INST]',
                '▁', '###', '<|assistant|>', '<|user|>',
                '<<SYS>>', '<</SYS>>', '[', ']', '<think>', '</think>'
            ]
            for token in special_tokens:
                token_ids = tokenizer.encode(token, add_special_tokens=False)
                self.special_token_ids.update(token_ids)
            # Also add any special tokens registered on the tokenizer itself
            if hasattr(tokenizer, 'all_special_ids'):
                self.special_token_ids.update(tokenizer.all_special_ids)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.get("labels")
        if labels is None:
            return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)

        # Forward pass to get logits
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Per-token weight tensor (defaults to 1.0 everywhere)
        weights = torch.ones_like(labels, dtype=torch.float, device=labels.device)

        # For each sequence in the batch
        for seq_idx in range(labels.shape[0]):
            seq_labels = labels[seq_idx]
            seq_input_ids = inputs["input_ids"][seq_idx]
            in_response = False
            response_pos = 0
            found_content_token = False
            for pos_idx in range(len(seq_labels)):
                if seq_labels[pos_idx] == -100:
                    # Masked token (prompt or padding)
                    in_response = False
                    response_pos = 0
                    found_content_token = False
                    weights[seq_idx, pos_idx] = 0.0
                else:
                    # Response token
                    if not in_response:
                        in_response = True
                        response_pos = 0
                        found_content_token = False
                    # Check whether this is a special/template token we should skip
                    token_id = int(seq_input_ids[pos_idx + 1]) if pos_idx + 1 < len(seq_input_ids) else int(seq_input_ids[pos_idx])
                    is_special = token_id in self.special_token_ids
                    if is_special and not found_content_token:
                        # Leading template tokens keep normal weight
                        weights[seq_idx, pos_idx] = 1.0
                    else:
                        found_content_token = True
                        # Exponential decay: weight = early_weight * exp(-decay_rate * position)
                        decayed_weight = self.early_weight * torch.exp(
                            torch.tensor(-self.decay_rate * response_pos, device=labels.device)
                        )
                        weight = torch.max(decayed_weight, torch.tensor(self.min_weight, device=labels.device))
                        weights[seq_idx, pos_idx] = weight.item()
                        response_pos += 1

        # Compute weighted cross-entropy loss on shifted logits/labels
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_weights = weights[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        flat_weights = shift_weights.view(-1)

        # Per-token losses
        losses = loss_fct(flat_logits, flat_labels)

        # Apply weights and mask out ignored positions
        mask = (flat_labels != -100).float()
        weighted_losses = losses * flat_weights * mask

        # Weighted average over non-masked tokens
        total_loss = weighted_losses.sum()
        total_weight = (flat_weights * mask).sum()
        loss = total_loss / (total_weight + 1e-8)

        return (loss, outputs) if return_outputs else loss
```
This is what GPT-5 generated for me. To be honest, I don't understand how it works, but I think it expects the user input to be masked, since the loss is meant to be computed only for responses. It also uses 5-6x more VRAM and takes about twice as long to train as stock SFT; reportedly (according to GPT-5 and Opus), this is why:
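For context on that masking assumption: the trainer treats any label equal to -100 as a prompt or padding token and only decay-weights the remaining (response) positions. A minimal, hypothetical illustration with made-up token IDs, assuming the collator masks the prompt the way TRL's completion-only collators do:

```python
import torch

# Hypothetical single-sequence batch: 4 prompt tokens followed by 4 response tokens
input_ids = torch.tensor([[101, 2054, 2003, 102, 7592, 2088, 999, 2]])
labels = input_ids.clone()
labels[0, :4] = -100  # prompt positions are ignored by the loss

# WeightedSFTTrainer would then assign weight early_weight * exp(-decay_rate * i)
# to the i-th response token, e.g. with early_weight=10 and decay_rate=0.1:
positions = torch.arange(4, dtype=torch.float)
print(10.0 * torch.exp(-0.1 * positions))  # tensor([10.0000, 9.0484, 8.1873, 7.4082])
```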
-------

## The 4-5x VRAM explosion

**1) KV cache left enabled (biggest culprit)**

```python
# Your model likely has this enabled by default
model.config.use_cache = True  # YOU NEED TO SET THIS TO FALSE
```

- During training, this caches key/value tensors for every attention layer for "generation efficiency"
- With your settings: 12k context × batch 8 × ~32 layers × hidden dims = **~10-15 GB extra**
- Stock SFTTrainer automatically disables this. You didn't.
**2) Double loss computation**

```python
# What happens in your compute_loss:
outputs = model(**inputs)  # ← Model sees 'labels', computes its own CE loss internally
logits = outputs.logits    # ← You then compute ANOTHER loss from the same logits
```

- The model builds a full forward graph for its internal loss and keeps those activations alive
- You build a second graph for your custom loss
- **Result**: ~2x memory for the loss-computation paths
**3) Memory fragmentation from thousands of tiny allocations**

```python
# This line runs thousands of times per batch:
torch.tensor(-self.decay_rate * float(response_pos), device=labels.device)
# And this:
weight_val.item()  # Moves a scalar from GPU to CPU
```

- Each `torch.tensor(..., device=labels.device)` call allocates a tiny GPU block
- Each `.item()` forces a GPU→CPU transfer
- PyTorch's caching allocator grows to satisfy these and never shrinks during training
- **Result**: "reserved" memory balloons due to fragmentation (see the precomputed-schedule sketch below)
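One cheap way to eliminate both the per-token `torch.tensor(...)` allocations and the `.item()` syncs is to precompute the decay schedule once per batch. A minimal sketch, reusing the `early_weight` / `decay_rate` / `min_weight` parameters from the trainer above (the helper name is my own, not from the original code):

```python
import torch

def build_decay_table(max_len, early_weight, decay_rate, min_weight, device):
    # One allocation for the whole schedule instead of one tiny tensor per token:
    # table[i] = max(early_weight * exp(-decay_rate * i), min_weight)
    positions = torch.arange(max_len, dtype=torch.float, device=device)
    return (early_weight * torch.exp(-decay_rate * positions)).clamp_min(min_weight)

# In the loop, weights[seq_idx, pos_idx] = table[response_pos] replaces the
# torch.tensor(...) / torch.max(...) / .item() sequence entirely.
```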
**4) Autograd graph pollution**

```python
# Inside your loss computation (inside the autograd context):
decayed_weight = self.early_weight * torch.exp(torch.tensor(-self.decay_rate * response_pos))
```

- Tensors are created inside the loss computation while autograd is recording
- They get added to the autograd graph unnecessarily
- More tensors kept alive → more memory
## The 50% GPU utilization (sawtooth pattern)

**The smoking gun: `.item()` calls in nested loops**

```python
for seq_idx in range(batch_size):      # 8 iterations
    for pos_idx in range(seq_len):     # 12,000 iterations
        if seq_labels[pos_idx].item() == -100:              # ← GPU→CPU sync
            # ...
        token_id = int(seq_input_ids[next_pos].item())      # ← another GPU→CPU sync
        weight_val = torch.max(decayed_weight, min_w)
        weights[seq_idx, pos_idx] = float(weight_val.item())  # ← another GPU→CPU sync
```

**What happens:**

- **8 × 12,000 = 96,000** `.item()` calls per batch
- Each `.item()` forces the GPU to:
  1. Stop current work
  2. Copy one scalar from GPU memory to the CPU
  3. Wait for the CPU to acknowledge
  4. Resume GPU work
- **Timeline per batch:**

```
Forward pass:   GPU 100% ████████████████
Your loss:      GPU 0%   (CPU doing .item() loops)
Backward pass:  GPU 100% ████████████████
Your loss:      GPU 0%   (CPU doing more loops)
```
**Why stock SFTTrainer doesn't have this:**

```python
# Stock SFTTrainer compute_loss (simplified):
def compute_loss(self, model, inputs, return_outputs=False):
    outputs = model(**inputs)  # Model computes the loss internally, returns a scalar
    return outputs.loss        # No Python loops, no .item() calls, no custom processing
```

- Zero Python-side per-token processing
- Loss computed entirely on the GPU with fused kernels
- GPU utilization: a steady 100% during forward + backward
## How to fix it

**1) Disable the KV cache**

```python
model.config.use_cache = False  # Add this before training
```

**2) Avoid double loss computation**

```python
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    labels = inputs.pop("labels")  # Remove labels so the model doesn't compute its own loss
    outputs = model(**inputs)      # Forward pass only
    logits = outputs.logits
    # ... your custom loss computation
    inputs["labels"] = labels      # Put back for logging if needed
```

**3) Move weight computation out of the hot path** (a concrete vectorized sketch follows below)

```python
# Option A: Precompute during dataset preprocessing
# Store weights alongside input_ids/labels in your cached dataset

# Option B: Vectorized computation (no Python loops)
@torch.no_grad()  # Don't pollute the autograd graph
def build_weights_vectorized(labels, input_ids, ...):
    # Compute entirely on CPU or with vectorized GPU ops
    # No .item() calls, no tiny tensor creation
```
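For reference, here is a minimal vectorized sketch of Option B. It keeps the same `early_weight * exp(-decay_rate * position)` schedule as the trainer above but, for brevity, skips the special-token handling; the exact implementation is my own assumption, not something GPT-5 or Opus produced:

```python
import torch

@torch.no_grad()
def build_weights_vectorized(labels, early_weight=10.0, decay_rate=0.1, min_weight=0.0):
    """Per-token weights that decay with position inside each response span.

    labels: (batch, seq_len) tensor where prompt/padding positions are -100.
    """
    is_response = labels != -100                   # (B, T) bool mask
    cum = torch.cumsum(is_response.int(), dim=1)   # running count of response tokens

    # Mark the first token of each response span (previous position is masked)
    prev = torch.zeros_like(is_response)
    prev[:, 1:] = is_response[:, :-1]
    span_start = is_response & ~prev

    # Offset of the current span = response-token count just before the span began;
    # cummax carries that offset forward across the span
    start_cum = torch.where(span_start, cum - 1, torch.zeros_like(cum))
    start_cum = torch.cummax(start_cum, dim=1).values

    pos_in_span = (cum - 1 - start_cum).clamp_min(0).float()
    weights = (early_weight * torch.exp(-decay_rate * pos_in_span)).clamp_min(min_weight)
    return torch.where(is_response, weights, torch.zeros_like(weights))
```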
**4) If you must keep the loop approach, do it on the CPU:**

```python
@torch.no_grad()
def build_weights_cpu(labels, input_ids, special_token_ids, early_weight, decay_rate, min_weight):
    labels_cpu = labels.cpu()        # Move once
    input_ids_cpu = input_ids.cpu()  # Move once
    # ... do all the Python loops on CPU tensors (no .item() needed)
    # ... build weights_cpu as a regular Python list / numpy array
    return torch.tensor(weights_cpu, device=labels.device)  # Move back once
```
## The bottom line

- **VRAM explosion**: mostly the KV cache + the double loss computation + memory fragmentation
- **GPU starvation**: 96,000 GPU→CPU syncs per batch from `.item()` calls in nested loops
- **Fix**: disable the cache, avoid the double loss, and move weight computation out of the training loop

With those changes, you should see VRAM drop to ~1.2x stock SFTTrainer and steady GPU utilization.
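Putting those fixes together, a slimmed-down `compute_loss` for the custom trainer might look roughly like this. This is my own assembly of the advice above, not code from GPT-5 or Opus; it assumes a helper like the `build_weights_vectorized` sketch and that `model.config.use_cache = False` was set before training:

```python
import torch
import torch.nn.functional as F

# Method body sketch for WeightedSFTTrainer
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    labels = inputs.pop("labels")   # keep the model from computing its own CE loss
    outputs = model(**inputs)       # single forward pass
    logits = outputs.logits

    # Vectorized weights: no Python loops, no .item() syncs, no tiny allocations
    weights = build_weights_vectorized(labels, self.early_weight, self.decay_rate, self.min_weight)

    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_weights = weights[..., 1:].contiguous()

    losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
        ignore_index=-100,          # ignored positions contribute zero loss
    )
    mask = (shift_labels.view(-1) != -100).float()
    weighted = losses * shift_weights.view(-1) * mask
    loss = weighted.sum() / ((shift_weights.view(-1) * mask).sum() + 1e-8)

    inputs["labels"] = labels       # restore for logging if needed
    return (loss, outputs) if return_outputs else loss
```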