Guest User

Untitled

a guest
Jul 26th, 2025
25
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.11 KB | None | 0 0
  1. #!/usr/bin/env python3
  2. import os
  3. import re
  4. import torch
  5. # ВАЖНО: unsloth надо импортировать до transformers/peft
  6. import unsloth
  7. from unsloth import FastLanguageModel
  8. from datasets import Dataset
  9. from transformers import Trainer, TrainingArguments, AutoTokenizer
  10. from typing import List, Dict
  11. from dataclasses import dataclass
  12.  
  13. # ─── CONFIG ────────────────────────────────────────────────────────────────
  14. BASE_MODEL_DIR = "./Vikhr-Nemo-12B-Instruct-R-21-09-24"
  15. DATA_FILE      = "./dataset.txt"
  16. OUTPUT_DIR     = "./lora_out"
  17.  
  18. # Жёстко ограничиваем max context длиной 2048
  19. MAX_LEN = 2048
  20.  
  21. def bf16_ok():
  22.     return torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8
  23.  
  24. dtype = torch.bfloat16 if bf16_ok() else torch.float16
  25. print("Using dtype =", dtype)
  26.  
  27. # ─── 1. ЗАГРУЖАЕМ ДАННЫЕ ─────────────────────────────────────────────────────
  28. with open(DATA_FILE, encoding="utf-8") as f:
  29.     raw = f.read().strip().split("\n\n\n")
  30. dataset = Dataset.from_list([{"text": block} for block in raw])
  31. print(f"Loaded {len(dataset)} dialogues")
  32.  
  33. # ─── 2. ИНИЦИАЛИЗАЦИЯ МОДЕЛИ ─────────────────────────────────────────────────
  34. model, tokenizer = FastLanguageModel.from_pretrained(
  35.     BASE_MODEL_DIR,
  36.     dtype=dtype,
  37.     load_in_4bit=False
  38. )
  39.  
  40. # Принудительно ставим модельный max_length = 2048
  41. tokenizer.model_max_length = MAX_LEN
  42. print("Forced model max_length =", MAX_LEN)
  43.  
  44. # ─── 3. ВСТАВЛЯЕМ LoRA ───────────────────────────────────────────────────────
  45. model = FastLanguageModel.get_peft_model(
  46.     model,
  47.     r              = 16,
  48.     lora_alpha     = 32,
  49.     target_modules = [
  50.         "q_proj","k_proj","v_proj","o_proj",
  51.         "gate_proj","up_proj","down_proj"
  52.     ],
  53.     lora_dropout   = 0.05,
  54.     bias           = "none"
  55. )
  56.  
  57. # ─── 4. ТОКЕНИЗАЦИЯ + МАСКИРОВКА ─────────────────────────────────────────────
  58. speaker_re = re.compile(r"(USER:|ASSISTANT:)", flags=re.IGNORECASE)
  59.  
  60. def tok_fn(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
  61.     all_input_ids = []
  62.     all_attn_masks = []
  63.     all_labels = []
  64.  
  65.     for text in batch["text"]:
  66.         # Разбиваем на сегменты по USER:/ASSISTANT:
  67.         parts = speaker_re.split(text)
  68.         utterances = []
  69.         for i in range(1, len(parts), 2):
  70.             speaker = parts[i][:-1].upper()  # "USER" или "ASSISTANT"
  71.             content = parts[i] + parts[i+1]
  72.             utterances.append((speaker, content))
  73.  
  74.         ids: List[int] = []
  75.         masks: List[int] = []
  76.         labels: List[int] = []
  77.  
  78.         for speaker, utt in utterances:
  79.             toks = tokenizer(utt, add_special_tokens=False)["input_ids"]
  80.             for tid in toks:
  81.                 ids.append(tid)
  82.                 masks.append(1)
  83.                 labels.append(tid if speaker == "ASSISTANT" else -100)
  84.  
  85.         # EOS токен (игнорим в лейблах)
  86.         if tokenizer.eos_token_id is not None:
  87.             ids.append(tokenizer.eos_token_id)
  88.             masks.append(1)
  89.             labels.append(-100)
  90.  
  91.         # Жёстко обрезаем всё до MAX_LEN
  92.         if len(ids) > MAX_LEN:
  93.             ids    = ids[:MAX_LEN]
  94.             masks  = masks[:MAX_LEN]
  95.             labels = labels[:MAX_LEN]
  96.  
  97.         all_input_ids.append(ids)
  98.         all_attn_masks.append(masks)
  99.         all_labels.append(labels)
  100.  
  101.     return {
  102.         "input_ids": all_input_ids,
  103.         "attention_mask": all_attn_masks,
  104.         "labels": all_labels,
  105.     }
  106.  
  107. tokenized = dataset.map(
  108.     tok_fn,
  109.     batched=True,
  110.     remove_columns=["text"]
  111. )
  112. print("Tokenized; sample lengths:",
  113.       len(tokenized[0]["input_ids"]), len(tokenized[0]["labels"]))
  114.  
  115. # ─── 5. COLLATОР С ОБРЕЗКОЙ ──────────────────────────────────────────────────
  116. @dataclass
  117. class CollatorWithTruncation:
  118.     tokenizer: AutoTokenizer
  119.     max_length: int
  120.  
  121.     def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
  122.         # Паддинг по longest, но все не более max_length
  123.         batch = self.tokenizer.pad(
  124.             features,
  125.             padding="longest",
  126.             return_tensors="pt"
  127.         )
  128.         bsz, seq_len = batch["input_ids"].shape
  129.         for key, pad_value in [
  130.             ("input_ids", self.tokenizer.pad_token_id),
  131.             ("attention_mask", 0),
  132.             ("labels", -100)
  133.         ]:
  134.             data = batch[key]
  135.             if seq_len > self.max_length:
  136.                 data = data[:, :self.max_length]
  137.             elif seq_len < self.max_length:
  138.                 pad_shape = (bsz, self.max_length - seq_len)
  139.                 pad_tensor = torch.full(pad_shape, pad_value, dtype=data.dtype)
  140.                 data = torch.cat([data, pad_tensor], dim=1)
  141.             batch[key] = data
  142.         return batch
  143.  
  144. collator = CollatorWithTruncation(tokenizer, MAX_LEN)
  145.  
  146. # ─── 6. TRAINING ARGS ───────────────────────────────────────────────────────
  147. args = TrainingArguments(
  148.     output_dir                  = OUTPUT_DIR,
  149.     per_device_train_batch_size = 1,
  150.     gradient_accumulation_steps = 4,
  151.     learning_rate               = 2e-4,
  152.     num_train_epochs            = 1,
  153.     warmup_ratio                = 0.03,
  154.     lr_scheduler_type           = "cosine",
  155.     fp16                        = not bf16_ok(),
  156.     bf16                        = bf16_ok(),
  157.     logging_steps               = 10,
  158.     save_strategy               = "epoch",
  159. )
  160.  
  161. trainer = Trainer(
  162.     model           = model,
  163.     args            = args,
  164.     train_dataset   = tokenized,
  165.     data_collator   = collator,
  166.     tokenizer       = tokenizer
  167. )
  168.  
  169. # ─── 7. TRAIN! ───────────────────────────────────────────────────────────────
  170. trainer.train()
  171.  
  172. # ─── 8. SAVE ────────────────────────────────────────────────────────────────
  173. os.makedirs(OUTPUT_DIR, exist_ok=True)
  174. model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
  175. tokenizer.save_pretrained(OUTPUT_DIR)
  176. print("✓ LoRA saved to", os.path.join(OUTPUT_DIR, "adapter_model.safetensors"))
  177.  
Advertisement
Add Comment
Please, Sign In to add comment