Guest User

Untitled

a guest
Feb 15th, 2025
224
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 24.46 KB | None | 0 0
  1. """
  2. LlamaThink-8b-Instruct Finetuning Process
  3.  
  4. I recently created LlamaThink-8b-Instruct Full Instruct model: https://huggingface.co/DavidBrowne17/LlamaThink-8B-instruct
  5.  
  6. GGUF: LlamaThink-8b-Instruct-GGUF: https://huggingface.co/DavidBrowne17/LlamaThink-8B-instruct-GGUF
  7.  
  8. A few of you were curious about how I made it, so here is the process for finetuning a model with GRPO reinforcement learning.
  9.  
  10. Our goal is to make a thinker model. It's super easy: first we need a dataset. Here is a llama-cpp-python script to create one.
  11. """
  12.  
  13.  
  14. import json
  15. import gc
  16. import random
  17. import re
  18. from llama_cpp import Llama
  19. import textwrap
  20.  
# GGUF model files available to the generator; main() loads only the first entry.
MODEL_PATHS = [
    "YOUR MODEL GGUF HERE"
]

# JSONL file that main() appends one record per conversation to.
OUTPUT_FILE = "./enhanced_simple_dataset.jsonl"

NUM_CONVERSATIONS = 5000  # number of conversations main() generates
TURNS_PER_CONVO = 1  # user/assistant exchanges per conversation
MAX_TOKENS = 100  # max_tokens passed to every chat completion in call_model()

# Strings passed as `stop` to every chat completion in call_model().
STOP_TOKENS = [
    "</s>", "<|endoftext|>", "<<USR>>", "<</USR>>", "<</SYS>>", "<</USER>>",
    "<</ASSISTANT>>", "<|eot_id|>", "<|im_end|>", "user:", "User:", "user :",
    "User :", "[assistant]", "[[assistant]]", "[user]", "[[user]]",
    "[/assistant]", "[/user]", "[\\assistant]"
]

# Sent as the "user" turn when asking the model to invent a user question
# (see generate_user_message()).
USER_INSTRUCTION = (
    "You are engaging in a conversation with an AI designed for deep reasoning and structured thinking. "
    "Ask questions naturally while expecting insightful, multi-layered responses. "
    "Ask a unique, relevant question. "
    "Keep messages clear and concise. Respond only with the Question, nothing else."
)

# Prompts for the three generation stages:
#   "system_prompt" - used by generate_system_prompt() to invent a system prompt
#   "thinking"      - used by generate_response() for the <thinking> section
#   "final"         - used by generate_response() for the <answer> section
INSTRUCTIONS = {
    "system_prompt": textwrap.dedent(""" Generate a system prompt for an AI to follow. This is a prompt for how the AI should behave, e.g., You are a chatbot, assistant, maths teacher, etc. It should not be instructions for a specific task. Do not add any explanations, headers, or formatting. Only output the system prompt text. """).strip(),

    "thinking": (
        "You are an AI designed to think deeply about the conversation topic. "
        "This is your internal thought process which is not visible to the user. "
        "Explain to yourself how you figure out the answer. "
        "Consider the user's question carefully, analyze the context, and formulate a coherent response strategy. "
        "Ensure your thought process is logical and well-structured. Do not generate any headers."
    ),

    "final": (
        "You are the final reviewer ensuring the response meets high standards of quality and insight. "
        "Your goal is to:\n"
        "1. Maximize logical depth and engagement.\n"
        "2. Ensure the response is precise, well-reasoned, and helpful.\n"
        "3. Strengthen structured argumentation and clarity.\n"
        "4. Maintain a professional and well-organized tone.\n"
        "In your final response, reference the user-provided system prompt to ensure consistency and relevance. "
        "Be concise and give the final answer."
    )
}
  67.  
  68. def load_model(path):
  69.     """Loads a single model."""
  70.     try:
  71.         return Llama(model_path=path, n_ctx=16000, n_gpu_layers=-1, chat_format="llama-3")
  72.     except Exception as e:
  73.         print(f"Failed to load model {path}: {e}")
  74.         return None
  75.  
  76. def call_model(llm, messages):
  77.     """Calls the model using chat completion API and retries on failure."""
  78.     attempt = 0
  79.     while True:
  80.         attempt += 1
  81.         try:
  82.             result = llm.create_chat_completion(
  83.                 messages=messages,
  84.                 max_tokens=MAX_TOKENS,
  85.                 temperature=random.uniform(1.4, 1.7),
  86.                 top_k=random.choice([250, 350]),
  87.                 top_p=random.uniform(0.85, 0.95),
  88.                 seed=random.randint(1, 900000000),
  89.                 stop=STOP_TOKENS
  90.             )
  91.             response_text = result["choices"][0]["message"]["content"].strip()
  92.             if response_text:
  93.                 return response_text
  94.             else:
  95.                 print(f"Attempt {attempt}: Empty response. Retrying...")
  96.         except ValueError as e:
  97.             print(f"Attempt {attempt}: Model call error: {e}. Retrying...")
  98.         except KeyboardInterrupt:
  99.             print("\nManual interruption detected. Exiting retry loop.")
  100.             return "Error: Retry loop interrupted by user."
  101.         except Exception as e:
  102.             print(f"Unexpected error on attempt {attempt}: {e}. Retrying...")
  103.  
  104. def generate_system_prompt(llm):
  105.     messages = [{"role": "system", "content": INSTRUCTIONS["system_prompt"]}]
  106.     return call_model(llm, messages)
  107.  
  108. def generate_user_message(llm, system_prompt):
  109.     messages = [
  110.         {"role": "system", "content": system_prompt},
  111.         {"role": "user", "content": USER_INSTRUCTION}
  112.     ]
  113.     return call_model(llm, messages)
  114.  
  115. def trim_to_last_complete_sentence(text):
  116.     """Trims text to the last complete sentence."""
  117.     matches = list(re.finditer(r'[.!?]', text))
  118.     return text[:matches[-1].end()] if matches else text
  119.  
  120. def generate_response(llm, conversation_history, system_prompt):
  121.     thinking = call_model(llm, [
  122.         {"role": "system", "content": system_prompt},
  123.         {"role": "user", "content": INSTRUCTIONS["thinking"]}
  124.     ])
  125.  
  126.     final_response = call_model(llm, [
  127.         {"role": "system", "content": system_prompt},
  128.         {"role": "user", "content": INSTRUCTIONS["final"]}
  129.     ])
  130.  
  131.     return f"<thinking>{trim_to_last_complete_sentence(thinking)}</thinking>\n\n<answer>{trim_to_last_complete_sentence(final_response)}</answer>"
  132.  
  133. def format_conversation(conversation):
  134.     return "\n".join(f"{entry['role']}: {entry['content']}" for entry in conversation)
  135.  
  136. def generate_conversation(llm):
  137.     conversation = []
  138.     system_prompt = generate_system_prompt(llm)
  139.  
  140.     for _ in range(TURNS_PER_CONVO):
  141.         user_message_text = generate_user_message(llm, system_prompt)
  142.         conversation.append({"role": "user", "content": user_message_text})
  143.  
  144.         conv_history_str = format_conversation(conversation)
  145.         assistant_message_text = generate_response(llm, conv_history_str, system_prompt)
  146.         conversation.append({"role": "assistant", "content": assistant_message_text})
  147.  
  148.     return system_prompt, conversation
  149.  
  150. def validate_json(data):
  151.     """Ensures JSON is valid before writing."""
  152.     try:
  153.         json.loads(json.dumps(data))
  154.         return True
  155.     except json.JSONDecodeError as e:
  156.         print(f"Invalid JSON detected: {e}")
  157.         return False
  158.  
  159. def main():
  160.     llm = load_model(MODEL_PATHS[0])
  161.     if not llm:
  162.         print("Failed to load the model. Exiting.")
  163.         return
  164.  
  165.     with open(OUTPUT_FILE, "a", encoding="utf-8") as out_f:
  166.         for convo_idx in range(NUM_CONVERSATIONS):
  167.             system_prompt, conversation = generate_conversation(llm)
  168.  
  169.             json_output = {
  170.                 "instruction": system_prompt.strip(),
  171.                 "conversation": conversation
  172.             }
  173.  
  174.             if validate_json(json_output):
  175.                 json_string = json.dumps(json_output, ensure_ascii=False)
  176.                 out_f.write(json_string + "\n")
  177.             else:
  178.                 print(f"Skipping malformed JSON for conversation {convo_idx}")
  179.  
  180.             if convo_idx % 100 == 0:
  181.                 print(f"Wrote conversation {convo_idx}/{NUM_CONVERSATIONS}")
  182.  
  183.     del llm
  184.     gc.collect()
  185.  
  186.     print(f"Dataset complete: {OUTPUT_FILE}")
  187.  
# Run the dataset generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
  190.  
  191. ##########################################################################################################################
  192. # grpo.py
  193.  
  194. """
  195. I set the limit to 5000, but we really only need about 300 results to finetune our model. I highly recommend changing the prompts slightly as you collect data to get a more diverse dataset; this will improve your final results. For example, tell it to be a mathematician, historian, etc., and to ask complex, advanced questions.
  196.  
  197. Once the dataset is ready, install unsloth. When the install is done, create a new file called grpo.py in the unsloth folder containing the following code, and place the dataset in the same directory as grpo.py.
  198. """
  199.  
  200. import sys
  201. import os
  202. import re
  203. import torch
  204. from typing import List
  205. from sentence_transformers import SentenceTransformer
  206. import numpy as np
  207.  
# Sentence-embedding model used by correctness_reward_func() to compare a
# completion's answer with the reference answer.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# NOTE(review): this is set after the embedder has been constructed; if that
# construction already initialises CUDA the flag may not take effect — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# On Windows, register a stub `resource` module (getrlimit/setrlimit no-ops)
# in sys.modules before the imports that follow.
# NOTE(review): presumably one of those imports expects the Unix-only
# `resource` module — confirm which.
if sys.platform == "win32":
    import types
    resource = types.ModuleType("resource")
    resource.getrlimit = lambda resource_id: (0, 0)
    resource.setrlimit = lambda resource_id, limits: None
    sys.modules["resource"] = resource
  217.  
  218. from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
  219. PatchFastRL("GRPO", FastLanguageModel)
  220. from datasets import load_dataset
  221. from trl import GRPOConfig, GRPOTrainer
  222. from transformers import AutoModelForCausalLM, AutoTokenizer
  223. from peft import LoraConfig, get_peft_model, PeftModel
  224.  
# Configuration
# NOTE(review): main() sets max_prompt_length=256 and max_completion_length=250,
# whose sum exceeds this value — confirm 256 is intended.
MAX_SEQ_LENGTH = 256  # passed as max_seq_length when loading the model
LORA_RANK = 16  # passed as max_lora_rank when loading the model
BASE_MODEL_NAME = "unsloth/Meta-Llama-3.1-8B-instruct"  # loaded for training and again for merging
DATASET_PATH = "enhanced_simple_dataset.jsonl"  # JSONL dataset read by main()
ADAPTER_SAVE_PATH = "grpo_adapter"  # where the trained adapter and tokenizer are saved
MERGED_MODEL_PATH = "merged_grpo_full"  # where the merged model and tokenizer are saved
# Appended to every example's own system prompt by format_dataset_entry().
SYSTEM_PROMPT = """ Respond in the following format: <thinking> ... </thinking> <answer> ... </answer> The thinking and answer portions should be no more than 100 tokens each. """
  233.  
  234. def format_dataset_entry(example):
  235.     """Format dataset entries for GRPO training."""
  236.     system_prompt = example.get("instruction", "")
  237.     conversation = example.get("conversation", [])
  238.  
  239.     messages = [{"role": "system", "content": system_prompt + SYSTEM_PROMPT}]
  240.  
  241.     if conversation and conversation[-1].get("role") == "assistant":
  242.         for turn in conversation[:-1]:
  243.             messages.append(turn)
  244.         answer = conversation[-1].get("content", "")
  245.     else:
  246.         for turn in conversation:
  247.             messages.append(turn)
  248.         answer = ""
  249.  
  250.     return {"prompt": messages, "answer": answer}
  251.  
  252. def extract_xml_answer(text: str) -> str:
  253.     answer = text.split("<answer>")[-1]
  254.     answer = answer.split("</answer>")[0]
  255.     return answer.strip()
  256.  
  257. def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
  258.     responses = [completion[0]['content'] for completion in completions]
  259.     q = prompts[0][-1]['content']
  260.     extracted_responses = [extract_xml_answer(r) for r in responses]
  261.  
  262.     print('-' * 20,
  263.           f"Question:\n{q}",
  264.           f"\nAnswer:\n{answer[0]}",
  265.           f"\nResponse:\n{responses[0]}",
  266.           f"\nExtracted:\n{extracted_responses[0]}")
  267.  
  268.     # Compute embeddings and cosine similarity
  269.     answer_embedding = embedder.encode(answer, convert_to_numpy=True)
  270.     response_embeddings = embedder.encode(extracted_responses, convert_to_numpy=True)
  271.  
  272.     similarities = [np.dot(r, answer_embedding) / (np.linalg.norm(r) * np.linalg.norm(answer_embedding))
  273.                     for r in response_embeddings]
  274.  
  275.     # Convert similarity to reward (scaled 0-2 range)
  276.     return [max(0.0, min(2.0, s * 2)) for s in similarities]
  277.  
  278. def int_reward_func(completions, **kwargs) -> list[float]:
  279.     responses = [completion[0]['content'] for completion in completions]
  280.     extracted_responses = [extract_xml_answer(r) for r in responses]
  281.     return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
  282.  
  283. def strict_format_reward_func(completions, *kwargs) -> list[float]:
  284.     pattern = r"<thinking>\n.?\n</thinking>\n<answer>\n.*?\n</answer>\n$"
  285.     responses = [completion[0]["content"] for completion in completions]
  286.     matches = [re.match(pattern, r) for r in responses]
  287.     return [0.5 if match else 0.0 for match in matches]
  288.  
  289. def soft_format_reward_func(completions, *kwargs) -> list[float]:
  290.     pattern = r"<thinking>.?</thinking>\s<answer>.?</answer>"
  291.     responses = [completion[0]["content"] for completion in completions]
  292.     matches = [re.match(pattern, r) for r in responses]
  293.     return [0.5 if match else 0.0 for match in matches]
  294.  
  295. def count_xml(text) -> float:
  296.     count = 0.0
  297.     if text.count("<thinking>\n") == 1:
  298.         count += 0.125
  299.     if text.count("\n</thinking>\n") == 1:
  300.         count += 0.125
  301.     if text.count("\n<answer>\n") == 1:
  302.         count += 0.125
  303.     count -= len(text.split("\n</answer>\n")[-1]) * 0.001
  304.     if text.count("\n</answer>") == 1:
  305.         count += 0.125
  306.     count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
  307.     return count
  308.  
  309. def xmlcount_reward_func(completions, **kwargs) -> list[float]:
  310.     contents = [completion[0]["content"] for completion in completions]
  311.     return [count_xml(c) for c in contents]
  312.  
  313. def main():
  314.     print("Loading model and tokenizer...")
  315.     model, tokenizer = FastLanguageModel.from_pretrained(
  316.         model_name=BASE_MODEL_NAME,
  317.         max_seq_length=MAX_SEQ_LENGTH,
  318.         load_in_4bit=True,
  319.         fast_inference=False,
  320.         max_lora_rank=LORA_RANK,
  321.         gpu_memory_utilization=0.9,
  322.         device_map={"": torch.cuda.current_device()}
  323.     )
  324.  
  325.     print("Applying GRPO adapter...")
  326.  
  327.     lora_config = LoraConfig(
  328.         r=16,
  329.         lora_alpha=16,
  330.         target_modules=[
  331.             "q_proj", "k_proj", "v_proj", "o_proj",
  332.             "gate_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"
  333.         ],
  334.         lora_dropout=0.05,
  335.         bias="none",
  336.         task_type="CAUSAL_LM",
  337.         inference_mode=False
  338.     )
  339.  
  340.     print("Applying QLoRA to the base model.")
  341.     model = get_peft_model(model, lora_config)
  342.     print("Loading and processing dataset...")
  343.     raw_dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
  344.     formatted_dataset = raw_dataset.map(format_dataset_entry)
  345.  
  346.     print("Configuring training...")
  347.     training_args = GRPOConfig(
  348.         use_vllm = False,
  349.         learning_rate = 5e-6,
  350.         adam_beta1 = 0.9,
  351.         adam_beta2 = 0.99,
  352.         weight_decay = 0.1,
  353.         warmup_ratio = 0.1,
  354.         lr_scheduler_type = "cosine",
  355.         optim = "paged_adamw_8bit",
  356.         logging_steps = 1,
  357.         bf16 = is_bfloat16_supported(),
  358.         fp16 = not is_bfloat16_supported(),
  359.         per_device_train_batch_size = 1
  360.         gradient_accumulation_steps = 1,
  361.         num_generations = 6, # Decrease if out of memory
  362.         max_prompt_length = 256,
  363.         max_completion_length = 250,
  364.         max_steps = 250,
  365.         save_steps = 10,
  366.         max_grad_norm = 0.1,
  367.         report_to = "none",
  368.         output_dir = "outputs",
  369.     )
  370.  
  371.     print("Initializing trainer...")
  372.     trainer = GRPOTrainer(
  373.         model=model,
  374.         processing_class=tokenizer,
  375.         reward_funcs=[
  376.             xmlcount_reward_func,
  377.             soft_format_reward_func,
  378.             strict_format_reward_func,
  379.             int_reward_func,
  380.             correctness_reward_func,
  381.         ],
  382.         args=training_args,
  383.         train_dataset=formatted_dataset,
  384.     )
  385.  
  386.     print("Starting training...")
  387.     trainer.train()
  388.  
  389.     print(f"Saving GRPO adapter to {ADAPTER_SAVE_PATH}")
  390.     model.save_pretrained(ADAPTER_SAVE_PATH)
  391.     tokenizer.save_pretrained(ADAPTER_SAVE_PATH)
  392.  
  393.     print("Loading base model for merging...")
  394.     base_model = AutoModelForCausalLM.from_pretrained(
  395.         BASE_MODEL_NAME,
  396.         torch_dtype=torch.float16,
  397.         device_map={"": torch.cuda.current_device()}
  398.     )
  399.     base_model.config.pad_token_id = tokenizer.pad_token_id
  400.  
  401.     print("Merging GRPO adapter...")
  402.     grpo_model = PeftModel.from_pretrained(base_model, ADAPTER_SAVE_PATH)
  403.     merged_model = grpo_model.merge_and_unload()
  404.  
  405.     print(f"Saving merged model to {MERGED_MODEL_PATH}")
  406.     merged_model.save_pretrained(MERGED_MODEL_PATH)
  407.     tokenizer.save_pretrained(MERGED_MODEL_PATH)
  408.  
  409.     print("Process completed successfully!")
  410.  
  411. if __name__ == "main":
  412.     main()
  413.  
  414.  ##########################################################################################################################
  415.  #
  416. """
  417. We load and finetune the model in 4-bit, but merge the adapter into the full 16-bit model; training in 4-bit significantly speeds up the training time. For the most part your dataset doesn't need advanced coding info; it just needs to be simple and fit the format well so the model can learn to think. When this is finished you should have a complete finetuned thinking model. This code can also be used for smaller models like Llama-3b. Have fun machine learning!
  418.  
  419. If you crash mid-training, you can resume from your latest checkpoint by setting CHECKPOINT_PATH in the script below.
  420. """
  421.    
  422. import sys
  423. import os
  424. import re
  425. import torch
  426. from typing import List
  427.  
# On Windows, register a stub `resource` module (getrlimit/setrlimit no-ops)
# in sys.modules before the imports that follow.
# NOTE(review): presumably one of those imports expects the Unix-only
# `resource` module — confirm which.
if sys.platform == "win32":
    import types
    resource = types.ModuleType("resource")
    resource.getrlimit = lambda resource_id: (0, 0)
    resource.setrlimit = lambda resource_id, limits: None
    sys.modules["resource"] = resource
  434.  
  435. from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
  436. PatchFastRL("GRPO", FastLanguageModel)
  437. from datasets import load_dataset
  438. from trl import GRPOConfig, GRPOTrainer
  439. from transformers import AutoModelForCausalLM, AutoTokenizer
  440. from peft import LoraConfig, get_peft_model, PeftModel
  441. from sentence_transformers import SentenceTransformer
  442. import numpy as np
  443.  
# Sentence-embedding model used by correctness_reward_func() to compare a
# completion's answer with the reference answer.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
MAX_SEQ_LENGTH = 512  # passed as max_seq_length when loading the model
# Passed as max_lora_rank when loading the model.
# NOTE(review): the LoraConfig in main() hard-codes r=16 — confirm which rank is intended.
LORA_RANK = 32
BASE_MODEL_NAME = "unsloth/meta-Llama-3.1-8B-instruct"  # loaded for training and again for merging
# NOTE(review): the generator script writes "./enhanced_simple_dataset.jsonl";
# verify this file name is the dataset you mean to train on.
DATASET_PATH = "enhanced_dataset.jsonl"
ADAPTER_SAVE_PATH = "grpo_adapter"  # where the trained adapter and tokenizer are saved
MERGED_MODEL_PATH = "merged_grpo_full"  # where the merged model and tokenizer are saved
# Checkpoint directory to resume from; main() trains from scratch when this
# path does not exist.
CHECKPOINT_PATH = "YOUR_LATEST_CHECKPOINT"
# Appended to every example's own system prompt by format_dataset_entry().
SYSTEM_PROMPT = """
Respond in the following format:
<thinking>
...
</thinking>
<answer>
...
</answer>
"""
  461.  
  462. def format_dataset_entry(example):
  463.     """Format dataset entries for GRPO training."""
  464.     system_prompt = example.get("instruction", "")
  465.     conversation = example.get("conversation", [])
  466.    
  467.     messages = [{"role": "system", "content": system_prompt + SYSTEM_PROMPT}]
  468.    
  469.     if conversation and conversation[-1].get("role") == "assistant":
  470.         for turn in conversation[:-1]:
  471.             messages.append(turn)
  472.         answer = conversation[-1].get("content", "")
  473.     else:
  474.         for turn in conversation:
  475.             messages.append(turn)
  476.         answer = ""
  477.        
  478.     return {"prompt": messages, "answer": answer}
  479.  
  480. def extract_xml_answer(text: str) -> str:
  481.     answer = text.split("<answer>")[-1]
  482.     answer = answer.split("</answer>")[0]
  483.     return answer.strip()
  484.  
  485. def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
  486.     responses = [completion[0]['content'] for completion in completions]
  487.     q = prompts[0][-1]['content']
  488.     extracted_responses = [extract_xml_answer(r) for r in responses]
  489.  
  490.     print('-' * 20,
  491.           f"Question:\n{q}",
  492.           f"\nAnswer:\n{answer[0]}",
  493.           f"\nResponse:\n{responses[0]}",
  494.           f"\nExtracted:\n{extracted_responses[0]}")
  495.  
  496.     # Compute embeddings and cosine similarity
  497.     answer_embedding = embedder.encode(answer, convert_to_numpy=True)
  498.     response_embeddings = embedder.encode(extracted_responses, convert_to_numpy=True)
  499.  
  500.     similarities = [np.dot(r, answer_embedding) / (np.linalg.norm(r) * np.linalg.norm(answer_embedding))
  501.                     for r in response_embeddings]
  502.  
  503.     # Convert similarity to reward (scaled 0-2 range)
  504.     return [max(0.0, min(2.0, s * 2)) for s in similarities]
  505.  
  506. def int_reward_func(completions, **kwargs) -> list[float]:
  507.     responses = [completion[0]['content'] for completion in completions]
  508.     extracted_responses = [extract_xml_answer(r) for r in responses]
  509.     return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
  510.  
  511. def strict_format_reward_func(completions, **kwargs) -> list[float]:
  512.     pattern = r"^<thinking>\n.*?\n</thinking>\n<answer>\n.*?\n</answer>\n$"
  513.     responses = [completion[0]["content"] for completion in completions]
  514.     matches = [re.match(pattern, r) for r in responses]
  515.     return [0.5 if match else 0.0 for match in matches]
  516.  
  517. def soft_format_reward_func(completions, **kwargs) -> list[float]:
  518.     pattern = r"<thinking>.*?</thinking>\s*<answer>.*?</answer>"
  519.     responses = [completion[0]["content"] for completion in completions]
  520.     matches = [re.match(pattern, r) for r in responses]
  521.     return [0.5 if match else 0.0 for match in matches]
  522.  
  523. def count_xml(text) -> float:
  524.     count = 0.0
  525.     if text.count("<thinking>\n") == 1:
  526.         count += 0.125
  527.     if text.count("\n</thinking>\n") == 1:
  528.         count += 0.125
  529.     if text.count("\n<answer>\n") == 1:
  530.         count += 0.125
  531.         count -= len(text.split("\n</answer>\n")[-1])*0.001
  532.     if text.count("\n</answer>") == 1:
  533.         count += 0.125
  534.         count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
  535.     return count
  536.  
  537. def xmlcount_reward_func(completions, **kwargs) -> list[float]:
  538.     contents = [completion[0]["content"] for completion in completions]
  539.     return [count_xml(c) for c in contents]
  540.  
  541. def main():
  542.     print("Loading model and tokenizer...")
  543.     model, tokenizer = FastLanguageModel.from_pretrained(
  544.         model_name=BASE_MODEL_NAME,
  545.         max_seq_length=MAX_SEQ_LENGTH,
  546.         load_in_4bit=True,
  547.         fast_inference=False,
  548.         max_lora_rank=LORA_RANK,
  549.         gpu_memory_utilization=0.9,
  550.         device_map={"": torch.cuda.current_device()}
  551.     )
  552.  
  553.     print("Applying GRPO adapter...")
  554.     lora_config = LoraConfig(
  555.         r=16,
  556.         lora_alpha=16,
  557.         target_modules=[
  558.             "q_proj", "k_proj", "v_proj", "o_proj",
  559.             "gate_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"
  560.         ],
  561.         lora_dropout=0.05,
  562.         bias="none",
  563.         task_type="CAUSAL_LM",
  564.         inference_mode=False
  565.     )
  566.  
  567.     print("Applying QLoRA to the base model.")
  568.     model = get_peft_model(model, lora_config)
  569.  
  570.     print("Loading and processing dataset...")
  571.     raw_dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
  572.     formatted_dataset = raw_dataset.map(format_dataset_entry)
  573.  
  574.     print("Configuring training...")
  575.     training_args = GRPOConfig(
  576.         use_vllm = False,
  577.         learning_rate = 5e-6,
  578.         adam_beta1 = 0.9,
  579.         adam_beta2 = 0.99,
  580.         weight_decay = 0.1,
  581.         warmup_ratio = 0.1,
  582.         lr_scheduler_type = "cosine",
  583.         optim = "paged_adamw_8bit",
  584.         logging_steps = 1,
  585.         bf16 = is_bfloat16_supported(),
  586.         fp16 = not is_bfloat16_supported(),
  587.         per_device_train_batch_size = 1,
  588.         gradient_accumulation_steps = 1,
  589.         num_generations = 6,
  590.         max_prompt_length = 256,
  591.         max_completion_length = 250,
  592.         num_train_epochs = 1,
  593.         max_steps = 250,
  594.         save_steps = 10,
  595.         max_grad_norm = 0.1,
  596.         report_to = "none",
  597.         output_dir = "outputs",
  598.     )
  599.  
  600.     print("Initializing trainer...")
  601.     trainer = GRPOTrainer(
  602.         model=model,
  603.         processing_class=tokenizer,
  604.         reward_funcs=[
  605.             xmlcount_reward_func,
  606.             soft_format_reward_func,
  607.             strict_format_reward_func,
  608.             int_reward_func,
  609.             correctness_reward_func,
  610.         ],
  611.         args=training_args,
  612.         train_dataset=formatted_dataset,
  613.     )
  614.  
  615.     print("Starting training...")
  616.     try:
  617.         if os.path.exists(CHECKPOINT_PATH):
  618.             print(f"Resuming training from checkpoint: {CHECKPOINT_PATH}")
  619.             trainer.train(resume_from_checkpoint=CHECKPOINT_PATH)
  620.         else:
  621.             print("No checkpoint found; starting training from scratch...")
  622.             trainer.train()
  623.  
  624.         # Save the adapter
  625.         print(f"Saving GRPO adapter to {ADAPTER_SAVE_PATH}")
  626.         if not os.path.exists(ADAPTER_SAVE_PATH):
  627.             os.makedirs(ADAPTER_SAVE_PATH)
  628.         model.save_pretrained(ADAPTER_SAVE_PATH)
  629.         tokenizer.save_pretrained(ADAPTER_SAVE_PATH)
  630.  
  631.     except Exception as e:
  632.         print(f"Error during training or saving: {str(e)}")
  633.         raise
  634.  
  635.     try:
  636.         print("Loading base model in full precision...")
  637.         base_model = AutoModelForCausalLM.from_pretrained(
  638.             BASE_MODEL_NAME,
  639.             torch_dtype=torch.float16,
  640.             device_map={"": torch.cuda.current_device()}
  641.         )
  642.  
  643.         base_model.config.pad_token_id = tokenizer.pad_token_id
  644.  
  645.         print("Loading and merging GRPO adapter...")
  646.         grpo_model = PeftModel.from_pretrained(base_model, ADAPTER_SAVE_PATH)
  647.         merged_model = grpo_model.merge_and_unload()
  648.  
  649.         if not os.path.exists(MERGED_MODEL_PATH):
  650.             os.makedirs(MERGED_MODEL_PATH)
  651.  
  652.         print(f"Saving merged model to {MERGED_MODEL_PATH}")
  653.         merged_model.save_pretrained(MERGED_MODEL_PATH)
  654.         tokenizer.save_pretrained(MERGED_MODEL_PATH)
  655.  
  656.         print("Process completed successfully!")
  657.  
  658.     except Exception as e:
  659.         print(f"Error during model merging: {str(e)}")
  660.         raise
  661.  
# Run training only when this file is executed as a script.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment