#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
DeepSpeed ZeRO for memory-efficient training (stage 3).
4-way experiment toggling:
  --lambda_schedule [true|false]
  --lambda_warmup   <int>
  --should_rms      [true|false]
"""

import argparse
import sys
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hugging Face Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from transformers.trainer_utils import TrainOutput
from torch.optim import AdamW

import datasets
import mmfreelm.ops.fusedbitnet as fuse  # from matmulfreellm (fusedbitnet module)


_current_lambda = 1.0
_global_should_rms = True  # If True, enable RMS in BitLinear


def get_current_lambda():
    """Return the current global lambda."""
    return _current_lambda
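# NOTE: The module docstring mentions a --lambda_schedule / --lambda_warmup experiment,
# but this script never updates _current_lambda during training. Below is a minimal
# sketch of how that could be done with a Hugging Face TrainerCallback; the linear
# 0 -> 1 ramp and the `warmup_steps` wiring are assumptions, not part of the original.
from transformers import TrainerCallback


class LambdaWarmupCallback(TrainerCallback):
    """Sketch: linearly ramp the global lambda from 0 to 1 over `warmup_steps` steps."""

    def __init__(self, warmup_steps=1000):
        self.warmup_steps = warmup_steps

    def on_step_begin(self, args, state, control, **kwargs):
        global _current_lambda
        if self.warmup_steps > 0:
            _current_lambda = min(1.0, state.global_step / self.warmup_steps)
        else:
            _current_lambda = 1.0

# Hypothetical wiring (not enabled in this script):
#     trainer.add_callback(LambdaWarmupCallback(warmup_steps=args.lambda_warmup))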
# ------------------------------------------------------------------------------
# Custom BitLinear Replacement
# ------------------------------------------------------------------------------
def replace_linear_with_fusedbit(model):
    """
    Recursively replace every nn.Linear module with a fuse.BitLinear module.

    Two extra arguments that the stock BitLinear does not take are passed
    (their effective defaults would be should_rms=True and lambda_=1.0):
      - lambda_    = get_current_lambda  (the layer calls this to read the global lambda)
      - should_rms = True
    """
    # Snapshot the module list first so the model is not mutated while
    # iterating over the named_modules() generator.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            fusedbit_layer = fuse.BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                lambda_=get_current_lambda,
                should_rms=True,
                bias=(module.bias is not None)
            )
            # fusedbit_layer.to(module.weight.device)

            # Copy existing weights/bias
            with torch.no_grad():
                fusedbit_layer.weight.copy_(module.weight)
                if module.bias is not None:
                    fusedbit_layer.bias.copy_(module.bias)

            # Replace the original module in its parent
            parent_path = name.rsplit('.', 1)
            if len(parent_path) == 1:
                setattr(model, parent_path[0], fusedbit_layer)
            else:
                parent_module_name, child_name = parent_path
                parent_module = model.get_submodule(parent_module_name)
                setattr(parent_module, child_name, fusedbit_layer)

    return model
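# NOTE: The flags listed in the module docstring (--lambda_schedule, --lambda_warmup,
# --should_rms) are never parsed in this script (argparse is imported but unused, and
# main() hard-codes "LambdaSchedule"/"RMS" in the experiment name). A minimal sketch of
# the intended parser, based only on the docstring; the defaults are assumptions.
def parse_args():
    parser = argparse.ArgumentParser(description="BitLinear / lambda-schedule experiments")
    parser.add_argument("--lambda_schedule", type=lambda s: s.lower() == "true",
                        default=True, help="Enable the lambda warmup schedule [true|false]")
    parser.add_argument("--lambda_warmup", type=int, default=1000,
                        help="Number of steps over which lambda ramps to 1.0")
    parser.add_argument("--should_rms", type=lambda s: s.lower() == "true",
                        default=True, help="Enable RMS normalization in BitLinear [true|false]")
    return parser.parse_args()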
def build_position_ids(input_ids):
    """
    input_ids: [batch_size, seq_length] (torch.LongTensor)
    returns:   position_ids of shape [batch_size, seq_length]
    """
    batch_size, seq_length = input_ids.shape
    return torch.arange(seq_length, dtype=torch.long, device=input_ids.device)\
                .unsqueeze(0).expand(batch_size, seq_length)


class ChatTemplateCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        batch_input_ids = []
        batch_attention_masks = []

        for ex in examples:
            # Build a 'chat' list recognized by apply_chat_template.
            # This depends on your custom approach to chat templates;
            # adjust as needed for your data structure.
            chat = []
            for c in ex["conversations"]:
                chat.append({"role": c["from"], "content": c["value"]})

            tokenized = self.tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                add_generation_prompt=False,
                return_dict=True,
                return_tensors="pt",
            )

            # Squeeze out the extra batch dimension
            input_ids = tokenized["input_ids"].squeeze(0)              # [seq_len]
            attention_mask = tokenized["attention_mask"].squeeze(0)    # [seq_len]

            # Truncate to max_length if needed
            if input_ids.size(0) > self.max_length:
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]

            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)

        # Pad the entire batch
        padded_input_ids = torch.nn.utils.rnn.pad_sequence(
            batch_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        ).clone()

        position_ids = build_position_ids(padded_input_ids).clone()

        padded_attention_masks = torch.nn.utils.rnn.pad_sequence(
            batch_attention_masks,
            batch_first=True,
            padding_value=0,
        )

        # For causal LM, labels = input_ids; mask padded positions with -100
        # so they are ignored by the cross-entropy loss.
        labels = padded_input_ids.clone()
        labels[padded_attention_masks == 0] = -100

        return {
            "input_ids": padded_input_ids,
            "attention_mask": padded_attention_masks,
            "labels": labels,
            "position_ids": position_ids,
        }
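# NOTE: Sketch of the record format ChatTemplateCollator assumes: ShareGPT-style
# "conversations" entries with "from"/"value" keys, where "from" already matches
# chat-template role names ("user"/"assistant"). The example text below is made up;
# adjust to the actual dataset schema.
def _collator_smoke_test(tokenizer):
    """Run the collator on one toy example and print the padded tensor shapes (sketch)."""
    example = {
        "conversations": [
            {"from": "user", "value": "What is 2 + 2?"},
            {"from": "assistant", "value": "4"},
        ]
    }
    collator = ChatTemplateCollator(tokenizer, max_length=512)
    batch = collator([example])
    for key, tensor in batch.items():
        print(f"{key}: {tuple(tensor.shape)}")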
def main():

    global _current_lambda, _global_should_rms
    _current_lambda = 1.0

    base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    print(f"Loading model from {base_model_name} ...")

    student = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    student = replace_linear_with_fusedbit(student)

    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    tokenizer.model_max_length = 8192 + 4096

    print("Loading dataset open-thoughts/OpenThoughts-114k...")
    train_dataset = datasets.load_dataset("open-thoughts/OpenThoughts-114k", split="train")

    data_collator = ChatTemplateCollator(tokenizer, max_length=8192 + 4096)  # or 16384

    # Build the experiment / output directory name
    exp_name = []
    exp_name.append(base_model_name.replace("/", "-"))
    exp_name.append("LambdaSchedule")  # if args.lambda_schedule else "NoLambda"
    exp_name.append("RMS")             # if args.should_rms else "NoRMS"
    output_dir = "_".join(exp_name)

    # DeepSpeed config
    import json
    ds_config = {
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 104857600,
            "allgather_partitions": True,
            "allgather_bucket_size": 104857600
        },
        "gradient_clipping": 1.0,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 4,
        "save_only_model": True,
        "save_steps": 500,
    }
    ds_config_path = "temp_ds_config.json"
    with open(ds_config_path, "w") as f:
        json.dump(ds_config, f)

    # TrainingArguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        remove_unused_columns=False,
        max_steps=4000,
        per_device_train_batch_size=1,
        save_steps=1000,
        logging_steps=10,
        evaluation_strategy="no",
        bf16=True,  # matches ds_config
        gradient_checkpointing=True,
        gradient_accumulation_steps=4,
        deepspeed=ds_config_path,
    )

    # Create Trainer
    trainer = Trainer(
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    print(f"Starting training for {output_dir} with DeepSpeed ZeRO ...")
    trainer.train()

    # Save final model
    final_save_dir = "final_" + output_dir
    trainer.save_model(final_save_dir)
    print(f"✅ Done. Model saved to {final_save_dir}.")


if __name__ == "__main__":
    main()
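# NOTE: Because TrainingArguments points at temp_ds_config.json, the script is meant to be
# launched through a distributed launcher rather than plain `python`. A typical invocation
# (the GPU count is an assumption) would look like:
#
#     deepspeed --num_gpus=8 this_script.py
#
# or, equivalently, torchrun with the same number of processes.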