Guest User

Untitled

a guest
Dec 12th, 2024
63
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.88 KB | None | 0 0
  1. ########################################
  2. # Data Collator For Responses Only
  3. #   Works with a dataset of two columns, input -> output
  4. #   Produces examples with input masked for backprop
  5.  
  6. @dataclass
  7. class DataCollatorForResponsesOnly:
  8.     tokenizer: PreTrainedTokenizerBase
  9.     padding: bool = True
  10.     max_length: int = None
  11.     pad_to_multiple_of: int = None
  12.  
  13.     def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
  14.         # Extract input and output texts
  15.         inputs = [ex["input"] for ex in examples]
  16.         outputs = [ex["output"] for ex in examples]
  17.  
  18.         # Tokenize inputs and outputs separately
  19.         input_encodings = self.tokenizer(
  20.             inputs,
  21.             truncation=True,
  22.             max_length=self.max_length,
  23.             padding=False,  # Padding is applied after concatenation
  24.             add_special_tokens=False,  # Special tokens are handled during concatenation
  25.             return_tensors=None  # Return as lists for manual handling
  26.         )
  27.  
  28.         output_encodings = self.tokenizer(
  29.             outputs,
  30.             truncation=True,
  31.             max_length=self.max_length,
  32.             padding=False,
  33.             add_special_tokens=False,
  34.             return_tensors=None
  35.         )
  36.  
  37.         batch_input_ids = []
  38.         batch_attention_masks = []
  39.         batch_labels = []
  40.  
  41.         for i in range(len(examples)):
  42.             input_ids = torch.tensor(input_encodings["input_ids"][i], dtype=torch.long)
  43.             attention_mask = torch.tensor(input_encodings["attention_mask"][i], dtype=torch.long)
  44.             output_ids = torch.tensor(output_encodings["input_ids"][i], dtype=torch.long)
  45.             output_mask = torch.tensor(output_encodings["attention_mask"][i], dtype=torch.long)
  46.  
  47.             # Concatenate input and output
  48.             concat_input_ids = torch.cat([input_ids, output_ids], dim=0)
  49.             concat_attention_mask = torch.cat([attention_mask, output_mask], dim=0)
  50.  
  51.             # Create labels: -100 for input tokens, output tokens as-is for the output portion
  52.             labels = torch.full_like(concat_input_ids, -100)
  53.             input_length = input_ids.size(0)
  54.             labels[input_length:] = concat_input_ids[input_length:]
  55.  
  56.             batch_input_ids.append(concat_input_ids)
  57.             batch_attention_masks.append(concat_attention_mask)
  58.             batch_labels.append(labels)
  59.  
  60.         # Pad sequences to the same length
  61.         batch_input_ids = torch.nn.utils.rnn.pad_sequence(
  62.             batch_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
  63.         )
  64.         batch_attention_masks = torch.nn.utils.rnn.pad_sequence(
  65.             batch_attention_masks, batch_first=True, padding_value=0
  66.         )
  67.         batch_labels = torch.nn.utils.rnn.pad_sequence(
  68.             batch_labels, batch_first=True, padding_value=-100
  69.         )
  70.  
  71.         # Return the batched tensors
  72.         return {
  73.             "input_ids": batch_input_ids,
  74.             "attention_mask": batch_attention_masks,
  75.             "labels": batch_labels
  76.         }
  77.  
  78. ########################################
  79. # Quick sanity check to test the data collator and check the masking
  80.  
  81. # Initialize the DataCollator
  82. data_collator = DataCollatorForResponsesOnly(tokenizer=tokenizer, padding=True, max_length=CONTEXT_LENGTH)
  83.  
  84. def test_data_collator(data_collator):
  85.     # Mock dataset with input and output columns
  86.     examples = [
  87.         {"input": "What is the capital of France?", "output": "The capital of France is Paris."},
  88.         {"input": "Who wrote Hamlet?", "output": "Hamlet was written by William Shakespeare."}
  89.     ]
  90.  
  91.     # Collate the data
  92.     batch = data_collator(examples)
  93.  
  94.     # Decode and print results for easy verification
  95.     input_ids = batch["input_ids"]
  96.     labels = batch["labels"]
  97.  
  98.     print("=== Input IDs ===")
  99.     for ids in input_ids:
  100.         print(tokenizer.decode(ids.tolist(), skip_special_tokens=False))
  101.  
  102.     print("\n=== Labels ===")
  103.     for lbls in labels:
  104.         decoded = [tokenizer.decode([id_]) if id_ != -100 else "<mask>" for id_ in lbls.tolist()]
  105.         print(" ".join(decoded))
  106.  
  107. # Really important to test the collator!
  108. test_data_collator(data_collator)
  109.  
  110.  
  111. ########################################
  112. # Custom trainer, avoiding chat template and dataset processing
  113.  
  114. class NoProcessUnslorthTrainer(UnslothTrainer):
  115.     def _prepare_dataset(self, dataset, *args, **kwargs):
  116.         # Could have more side effects, but seems to work fine...
  117.         return dataset
  118.  
  119. # Create the trainer
  120. trainer = NoProcessUnslorthTrainer(
  121.     model = model,
  122.     tokenizer = tokenizer,
  123.     train_dataset = hf_dataset,
  124.  
  125.     data_collator=data_collator,
  126.  
  127.     dataset_text_field=None,
  128.     formatting_func=None,
  129.  
  130.     # Other trainer params ...
  131.  
  132.     args = UnslothTrainingArguments(
  133.         packing=False,
  134.         remove_unused_columns=False,
  135.  
  136.         # Other training arguments ...
  137.     ),
  138. )
Advertisement
Add Comment
Please, Sign In to add comment