#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
DeepSpeed ZeRO (stage 3) for memory-efficient training.

4-way experiment toggling (flags documented here but not parsed in this script;
the values are currently hardcoded below):
    --lambda_schedule [true|false]
    --lambda_warmup <int>
    --should_rms [true|false]
"""
import argparse
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hugging Face Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import TrainOutput
from torch.optim import AdamW

import datasets

import mmfreelm.ops.fusedbitnet as fuse  # from matmulfreellm (fusedbitnet module)
_current_lambda = 1.0
_global_should_rms = True  # If True, enable RMS in BitLinear


def get_current_lambda():
    """Return the current global lambda."""
    return _current_lambda
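# The module docstring mentions --lambda_schedule / --lambda_warmup, but no schedule is
# wired into the training loop in this paste. Below is a minimal sketch (an assumption,
# not part of the original script) of one way to do it: a TrainerCallback that linearly
# ramps the global lambda over a fixed number of optimizer steps, which the custom
# BitLinear layers would then pick up through get_current_lambda().
from transformers import TrainerCallback


class LambdaWarmupCallback(TrainerCallback):
    """Hypothetical helper: pass via Trainer(callbacks=[LambdaWarmupCallback(1000)])."""

    def __init__(self, warmup_steps=1000):
        self.warmup_steps = warmup_steps

    def on_step_begin(self, args, state, control, **kwargs):
        global _current_lambda
        # Linear ramp from 0 to 1 over warmup_steps, then hold at 1.0.
        _current_lambda = min(1.0, state.global_step / max(1, self.warmup_steps))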
# ------------------------------------------------------------------------------
# Custom BitLinear Replacement
# ------------------------------------------------------------------------------
def replace_linear_with_fusedbit(model):
    """
    Recursively replace every nn.Linear module with a fuse.BitLinear module.

    Two extra keyword arguments (not present in the stock BitLinear; this assumes a
    customized fusedbitnet build that accepts them) are passed:
      - lambda_    = get_current_lambda (the getter, so the layer can read the
                     current global lambda, which starts at 1.0)
      - should_rms = _global_should_rms (True: enable RMS normalization)
    """
    # Snapshot the module list first so we do not mutate the model while iterating.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            fusedbit_layer = fuse.BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                lambda_=get_current_lambda,
                should_rms=_global_should_rms,
                bias=(module.bias is not None),
            )
            # fusedbit_layer.to(module.weight.device)
            # Copy existing weights/bias
            with torch.no_grad():
                fusedbit_layer.weight.copy_(module.weight)
                if module.bias is not None:
                    fusedbit_layer.bias.copy_(module.bias)
            # Replace the original module in its parent
            parent_path = name.rsplit('.', 1)
            if len(parent_path) == 1:
                setattr(model, parent_path[0], fusedbit_layer)
            else:
                parent_module_name, child_name = parent_path
                parent_module = model.get_submodule(parent_module_name)
                setattr(parent_module, child_name, fusedbit_layer)
    return model
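# Quick sanity check after the swap (a hypothetical helper, not in the original paste):
# the count printed right after replace_linear_with_fusedbit should match the number of
# nn.Linear layers the base model had.
def count_bitlinear(model):
    """Count fuse.BitLinear modules so the replacement can be eyeballed."""
    return sum(1 for m in model.modules() if isinstance(m, fuse.BitLinear))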
def build_position_ids(input_ids):
    """
    input_ids: [batch_size, seq_length] (torch.LongTensor)
    returns:   position_ids: [batch_size, seq_length]
    """
    batch_size, seq_length = input_ids.shape
    return torch.arange(seq_length, dtype=torch.long, device=input_ids.device) \
        .unsqueeze(0).expand(batch_size, seq_length)
class ChatTemplateCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        batch_input_ids = []
        batch_attention_masks = []
        for ex in examples:
            # Build a 'chat' list recognized by apply_chat_template.
            # This depends on your custom approach to chat templates.
            # Adjust as needed for your data structure.
            chat = []
            for c in ex["conversations"]:
                chat.append({"role": c["from"], "content": c["value"]})
            tokenized = self.tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                add_generation_prompt=False,
                return_dict=True,
                return_tensors="pt",
            )
            # Squeeze out the extra batch dimension
            input_ids = tokenized["input_ids"].squeeze(0)            # [seq_len]
            attention_mask = tokenized["attention_mask"].squeeze(0)  # [seq_len]
            # Truncate overly long conversations
            if input_ids.size(0) > self.max_length:
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)
        # Pad the entire batch
        padded_input_ids = torch.nn.utils.rnn.pad_sequence(
            batch_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        ).clone()
        position_ids = build_position_ids(padded_input_ids).clone()
        padded_attention_masks = torch.nn.utils.rnn.pad_sequence(
            batch_attention_masks,
            batch_first=True,
            padding_value=0,
        )
        # For causal LM, labels = input_ids, with padding positions masked to -100
        # so no loss is computed on pad tokens.
        labels = padded_input_ids.clone()
        labels[padded_attention_masks == 0] = -100
        return {
            "input_ids": padded_input_ids,
            "attention_mask": padded_attention_masks,
            "labels": labels,
            "position_ids": position_ids,
        }
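# Minimal sanity check for the collator. This is a sketch under the assumption that each
# dataset row carries a ShareGPT-style "conversations" list with "from"/"value" keys and
# roles the tokenizer's chat template understands (e.g. "user"/"assistant"); the
# _demo_collator name is hypothetical and not part of the original script.
def _demo_collator(tokenizer):
    collator = ChatTemplateCollator(tokenizer, max_length=64)
    fake_rows = [
        {"conversations": [
            {"from": "user", "value": "What is 2 + 2?"},
            {"from": "assistant", "value": "4."},
        ]},
    ]
    batch = collator(fake_rows)
    # Expect matching [batch, seq_len] shapes for all four tensors.
    for k, v in batch.items():
        print(k, tuple(v.shape))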
def main():
    global _current_lambda, _global_should_rms
    _current_lambda = 1.0

    base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    print(f"Loading model from {base_model_name} ...")
    student = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    student = replace_linear_with_fusedbit(student)

    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 8192 + 4096

    print("Loading dataset open-thoughts/OpenThoughts-114k...")
    train_dataset = datasets.load_dataset("open-thoughts/OpenThoughts-114k", split="train")
    data_collator = ChatTemplateCollator(tokenizer, max_length=8192 + 4096)  # 12288 tokens (16384 is another option)

    # Experiment name
    exp_name = []
    exp_name.append(base_model_name.replace("/", "-"))
    exp_name.append("LambdaSchedule")  # if args.lambda_schedule else "NoLambda"
    exp_name.append("RMS")             # if args.should_rms else "NoRMS"
    output_dir = "_".join(exp_name)
    # DeepSpeed config
    import json
    ds_config = {
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 104857600,
            "allgather_partitions": True,
            "allgather_bucket_size": 104857600
        },
        "gradient_clipping": 1.0,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 4,
        # Note: save_only_model / save_steps are HF Trainer options rather than DeepSpeed
        # schema keys; checkpoint saving is actually governed by TrainingArguments below.
        "save_only_model": True,
        "save_steps": 500,
    }
    ds_config_path = "temp_ds_config.json"
    with open(ds_config_path, "w") as f:
        json.dump(ds_config, f)
    # TrainingArguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        remove_unused_columns=False,
        max_steps=4000,
        per_device_train_batch_size=1,
        save_steps=1000,
        logging_steps=10,
        evaluation_strategy="no",
        bf16=True,  # matches ds_config
        gradient_checkpointing=True,
        gradient_accumulation_steps=4,
        deepspeed=ds_config_path,
    )

    # Create Trainer
    trainer = Trainer(
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    print(f"Starting training for {output_dir} with DeepSpeed ZeRO ...")
    trainer.train()

    # Save final model
    final_save_dir = "final_" + output_dir
    trainer.save_model(final_save_dir)
    print(f"✅ Done. Model saved to {final_save_dir}.")


if __name__ == "__main__":
    main()