Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
########################################
# Data Collator For Responses Only
# Works with a dataset of two columns, input -> output
# Produces examples with input masked for backprop
@dataclass
class DataCollatorForResponsesOnly:
    """Collate ``{"input": str, "output": str}`` rows into a causal-LM batch.

    For every row the input and output texts are tokenized separately
    (without special tokens), concatenated, clipped to ``max_length`` and
    right-padded to a common length. Labels are ``-100`` (the default
    ``ignore_index`` of PyTorch's cross-entropy loss) on every input and
    padding position, so only output tokens are scored.

    No BOS/EOS tokens are inserted by this collator; if the model needs
    them they must already be present in the text columns.
    """

    # Callable tokenizer returning "input_ids" and "attention_mask" lists.
    # Quoted so the class does not need the type at import time.
    tokenizer: "PreTrainedTokenizerBase"
    # Kept for interface compatibility. Batches are always padded to the
    # longest sequence because the returned tensors must be rectangular.
    padding: bool = True
    # Upper bound on the length of the *concatenated* sequence
    # (None: only the tokenizer's own truncation default applies).
    max_length: int = None
    # If set, the padded length is rounded up to a multiple of this value.
    pad_to_multiple_of: int = None

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Build ``input_ids``, ``attention_mask`` and ``labels`` tensors.

        Args:
            examples: Rows with an ``"input"`` and an ``"output"`` text.

        Returns:
            Dict of three ``torch.long`` tensors of identical 2-D shape
            ``(len(examples), padded_length)``.

        Raises:
            ValueError: If ``examples`` is empty, or if the tokenizer has
                neither a pad token id nor an EOS token id to pad with.
        """
        if not examples:
            raise ValueError("DataCollatorForResponsesOnly received an empty batch.")

        # Padding positions are hidden by the attention mask and by -100
        # labels, so reusing EOS as the pad id is safe when no pad token
        # has been configured.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is None:
            raise ValueError(
                "Tokenizer defines neither pad_token_id nor eos_token_id; "
                "set tokenizer.pad_token before collating."
            )

        prompt_enc = self._tokenize([ex["input"] for ex in examples])
        response_enc = self._tokenize([ex["output"] for ex in examples])

        batch_input_ids = []
        batch_attention_masks = []
        batch_labels = []
        rows = zip(
            prompt_enc["input_ids"],
            prompt_enc["attention_mask"],
            response_enc["input_ids"],
            response_enc["attention_mask"],
        )
        for prompt_ids, prompt_mask, response_ids, response_mask in rows:
            ids = list(prompt_ids) + list(response_ids)
            mask = list(prompt_mask) + list(response_mask)
            # -100 on the prompt part, real token ids on the response part.
            labels = [-100] * len(prompt_ids) + list(response_ids)

            # Each side was truncated on its own, so the concatenation can
            # still be up to twice max_length; clip it to the real budget.
            # NOTE: a prompt that fills the whole budget leaves the row
            # without any supervised token.
            if self.max_length is not None:
                ids = ids[: self.max_length]
                mask = mask[: self.max_length]
                labels = labels[: self.max_length]

            batch_input_ids.append(torch.tensor(ids, dtype=torch.long))
            batch_attention_masks.append(torch.tensor(mask, dtype=torch.long))
            batch_labels.append(torch.tensor(labels, dtype=torch.long))

        # Pad every row on the right to one common length.
        target_len = max(seq.size(0) for seq in batch_input_ids)
        if self.pad_to_multiple_of:
            multiple = self.pad_to_multiple_of
            target_len = -(-target_len // multiple) * multiple  # ceil

        return {
            "input_ids": self._pad(batch_input_ids, pad_id, target_len),
            "attention_mask": self._pad(batch_attention_masks, 0, target_len),
            "labels": self._pad(batch_labels, -100, target_len),
        }

    def _tokenize(self, texts: List[str]):
        """Tokenize ``texts`` to plain Python lists, unpadded, no special tokens."""
        return self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=False,  # Padding is applied after concatenation
            add_special_tokens=False,  # Keep the prompt/response boundary exact
            return_tensors=None,  # Return as lists for manual handling
        )

    @staticmethod
    def _pad(sequences: List[torch.Tensor], padding_value: int, target_len: int) -> torch.Tensor:
        """Right-pad 1-D ``sequences`` with ``padding_value`` to ``target_len``."""
        padded = torch.full((len(sequences), target_len), padding_value, dtype=torch.long)
        for row, seq in zip(padded, sequences):
            row[: seq.size(0)] = seq
        return padded
########################################
# Quick sanity check to test the data collator and check the masking.
# Build the collator instance that is used for both the check and training.
data_collator = DataCollatorForResponsesOnly(
    tokenizer=tokenizer,
    padding=True,
    max_length=CONTEXT_LENGTH,
)
def test_data_collator(data_collator):
    """Print a decoded view of one collated batch for visual inspection.

    Two mock question/answer rows are run through ``data_collator``. The
    resulting ``input_ids`` are decoded and printed in full, then the
    ``labels`` are printed token by token with ``<mask>`` standing in for
    every ``-100`` position, so the masked/unmasked split can be checked
    by eye.

    Args:
        data_collator: Callable taking a list of ``{"input", "output"}``
            dicts and returning a batch with ``"input_ids"`` and
            ``"labels"``. It must expose the tokenizer it encodes with as
            ``data_collator.tokenizer``.
    """
    # Decode with the collator's own tokenizer rather than a module-level
    # global, so the printout always matches the ids that were produced.
    tokenizer = data_collator.tokenizer

    # Mock dataset with input and output columns
    examples = [
        {"input": "What is the capital of France?", "output": "The capital of France is Paris."},
        {"input": "Who wrote Hamlet?", "output": "Hamlet was written by William Shakespeare."},
    ]

    # Collate the data
    batch = data_collator(examples)

    # Decode and print results for easy verification
    print("=== Input IDs ===")
    for ids in batch["input_ids"]:
        print(tokenizer.decode(ids.tolist(), skip_special_tokens=False))

    print("\n=== Labels ===")
    for label_row in batch["labels"]:
        decoded = [
            tokenizer.decode([token_id]) if token_id != -100 else "<mask>"
            for token_id in label_row.tolist()
        ]
        print(" ".join(decoded))


# Despite its name this is a manual check that needs a collator argument,
# not a pytest test; opt out of collection so pytest does not look for a
# "data_collator" fixture.
test_data_collator.__test__ = False
# Really important to test the collator!
# Run the visual masking check once, before any training starts.
test_data_collator(data_collator)
########################################
# Custom trainer, avoiding chat template and dataset processing
class NoProcessUnslorthTrainer(UnslothTrainer):
    """``UnslothTrainer`` variant that leaves the training dataset untouched.

    The raw rows are handed to the data collator as-is instead of being
    tokenized or templated by the trainer.
    """

    def _prepare_dataset(self, dataset, *args, **kwargs):
        """Return ``dataset`` unchanged, ignoring every other argument.

        Presumably overrides the base trainer's dataset-preparation hook;
        confirm the hook name against the installed trainer version.
        """
        # Could have more side effects, but seems to work fine...
        return dataset
# Create the trainer
trainer = NoProcessUnslorthTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=hf_dataset,
    data_collator=data_collator,
    # No text field / formatting function: rows go to the collator as-is.
    dataset_text_field=None,
    formatting_func=None,
    # Other trainer params ...
    args=UnslothTrainingArguments(
        packing=False,
        # Presumably needed so the "input"/"output" columns read by the
        # collator are not dropped before collation; confirm.
        remove_unused_columns=False,
        # Other training arguments ...
    ),
)
Advertisement
Add Comment
Please, Sign In to add comment