Orpheus Local Finetune via Unsloth

a guest
Apr 3rd, 2025
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.77 KB | Source Code | 0 0

from unsloth import FastLanguageModel
import torch
import locale
import torchaudio.transforms as T
import os
from snac import SNAC
from datasets import load_dataset
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit",
    max_seq_length = 8192, # Choose any for long context!
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # "unsloth" gradient checkpointing uses ~30% less VRAM, fits 2x larger batch sizes
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

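# NOTE: "audiofolder" expects a metadata file (e.g. metadata.csv) inside data_dir mapping each
# audio file_name to its transcript; this script assumes the transcript column is named "text"
# (it is read later in create_input_ids). "myDataset" is a placeholder path.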
dataset = load_dataset("audiofolder", data_dir="myDataset", split="train")
locale.getpreferredencoding = lambda: "UTF-8"
ds_sample_rate = dataset[0]["audio"]["sampling_rate"]

snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to("cuda")

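# Each SNAC frame is flattened into 7 tokens below: 1 code from the coarse codebook, 2 from the
# medium codebook and 4 from the fine codebook, each shifted into its own 4096-wide slice of the
# model's audio-token range starting at 128266 (tokeniser_length + 10, defined further down).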
def tokenise_audio(waveform):
  waveform = torch.from_numpy(waveform).unsqueeze(0)
  waveform = waveform.to(dtype=torch.float32)
  resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
  waveform = resample_transform(waveform)

  waveform = waveform.unsqueeze(0).to("cuda")

  # generate the codes from snac
  with torch.inference_mode():
    codes = snac_model.encode(waveform)

  all_codes = []
  for i in range(codes[0].shape[1]):
    all_codes.append(codes[0][0][i].item()+128266)
    all_codes.append(codes[1][0][2*i].item()+128266+4096)
    all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
    all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
    all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
    all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
    all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))

  return all_codes

def add_codes(example):
    # Always initialize codes_list to None
    codes_list = None

    try:
        answer_audio = example.get("audio")
        # If there's a valid audio array, tokenise it
        if answer_audio and "array" in answer_audio:
            audio_array = answer_audio["array"]
            codes_list = tokenise_audio(audio_array)
    except Exception as e:
        print(f"Skipping row due to error: {e}")
        # Keep codes_list as None if we fail
    example["codes_list"] = codes_list

    return example

dataset = dataset.map(add_codes, remove_columns=["audio"])

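# Special control-token IDs appended after the base tokenizer vocabulary (128256 entries).
# audio_tokens_start (128266) matches the offset added to every SNAC code in tokenise_audio above.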
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009

start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2

start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4

start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7

audio_tokens_start = tokeniser_length + 10

dataset = dataset.filter(lambda x: x["codes_list"] is not None)
dataset = dataset.filter(lambda x: len(x["codes_list"]) > 0)

def remove_duplicate_frames(example):
    vals = example["codes_list"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    result = vals[:7]

    removed_frames = 0

    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]

        if current_first != previous_first:
            result.extend(vals[i:i+7])
        else:
            removed_frames += 1

    example["codes_list"] = result

    return example

dataset = dataset.map(remove_duplicate_frames)

tok_info = '''*** HERE you can modify the text prompt
If you are training a multi-speaker model (e.g., canopylabs/orpheus-3b-0.1-ft),
ensure that the dataset includes a "source" field and format the input accordingly:
- Single-speaker: f"{example['text']}"
- Multi-speaker: f"{example['source']}: {example['text']}"
'''
print(tok_info)

def create_input_ids(example):
    # Determine whether to include the source field
    text_prompt = f"{example['source']}: {example['text']}" if "source" in example else example["text"]

    text_ids = tokenizer.encode(text_prompt, add_special_tokens=True)
    text_ids.append(end_of_text)

    example["text_tokens"] = text_ids
    input_ids = (
        [start_of_human]
        + example["text_tokens"]
        + [end_of_human]
        + [start_of_ai]
        + [start_of_speech]
        + example["codes_list"]
        + [end_of_speech]
        + [end_of_ai]
    )
    example["input_ids"] = input_ids
    # Causal LM training on the full sequence, so labels mirror input_ids
    example["labels"] = input_ids
    example["attention_mask"] = [1] * len(input_ids)

    return example

dataset = dataset.map(create_input_ids, remove_columns=["text", "codes_list"])
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in dataset.column_names if col not in columns_to_keep]

dataset = dataset.remove_columns(columns_to_remove)
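# NOTE: no data_collator is passed to the Trainer below, so the default collator is used. That is
# fine at per_device_train_batch_size = 1, but with larger batches the variable-length sequences
# would need padding (e.g. via the imported-but-unused DataCollatorForSeq2Seq, which assumes the
# tokenizer's pad token is set).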

trainer = Trainer(
    model = model,
    train_dataset = dataset,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 130,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3207,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

trainer_stats = trainer.train()

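# Optional: persist the trained LoRA adapters (and tokenizer) so they can be reloaded later.
# The output path "lora_model" is just an illustrative choice.
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")
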
def gen_wav(prompts, chosen_voice):
    FastLanguageModel.for_inference(model)  # Enable native 2x faster inference

    # Move snac_model to the CPU to free VRAM for generation
    snac_model.to("cpu")

    prompts_ = [(f"{chosen_voice}: " + p) if chosen_voice else p for p in prompts]

    all_input_ids = []

    for prompt in prompts_:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        all_input_ids.append(input_ids)

    start_token = torch.tensor([[128259]], dtype=torch.int64)  # start_of_human
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end_of_text, end_of_human

    all_modified_input_ids = []
    for input_ids in all_input_ids:
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
        all_modified_input_ids.append(modified_input_ids)

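    # Left-pad every prompt to the longest length in the batch using the pad token (128263,
    # i.e. pad_token defined above) and build matching attention masks, so generation for all
    # prompts is right-aligned.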
    all_padded_tensors = []
    all_attention_masks = []
    max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])
    for modified_input_ids in all_modified_input_ids:
        padding = max_length - modified_input_ids.shape[1]
        padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
        attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64),
                                    torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
        all_padded_tensors.append(padded_tensor)
        all_attention_masks.append(attention_mask)

    all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
    all_attention_masks = torch.cat(all_attention_masks, dim=0)

    input_ids = all_padded_tensors.to("cuda")
    attention_mask = all_attention_masks.to("cuda")
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=8192,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,
        num_return_sequences=1,
        eos_token_id=128258,  # end_of_speech
        use_cache=True
    )
    token_to_find = 128257    # start_of_speech
    token_to_remove = 128258  # end_of_speech

    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

    # Keep only the tokens generated after the last start_of_speech marker
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
    else:
        cropped_tensor = generated_ids

    processed_rows = []
    for row in cropped_tensor:
        masked_row = row[row != token_to_remove]
        processed_rows.append(masked_row)

    code_lists = []
    for row in processed_rows:
        row_length = row.size(0)
        new_length = (row_length // 7) * 7  # trim to a whole number of 7-token frames
        trimmed_row = row[:new_length]
        trimmed_row = [t - 128266 for t in trimmed_row]
        code_lists.append(trimmed_row)

    def redistribute_codes(code_list):
        # Undo the interleaving from tokenise_audio: rebuild the three SNAC codebook layers
        layer_1 = []
        layer_2 = []
        layer_3 = []
        for i in range((len(code_list) + 1) // 7):
            layer_1.append(code_list[7 * i])
            layer_2.append(code_list[7 * i + 1] - 4096)
            layer_3.append(code_list[7 * i + 2] - (2 * 4096))
            layer_3.append(code_list[7 * i + 3] - (3 * 4096))
            layer_2.append(code_list[7 * i + 4] - (4 * 4096))
            layer_3.append(code_list[7 * i + 5] - (5 * 4096))
            layer_3.append(code_list[7 * i + 6] - (6 * 4096))
        codes = [torch.tensor(layer_1).unsqueeze(0),
                 torch.tensor(layer_2).unsqueeze(0),
                 torch.tensor(layer_3).unsqueeze(0)]
        # codes = [c.to("cuda") for c in codes]
        audio_hat = snac_model.decode(codes)
        return audio_hat

    my_samples = []
    for code_list in code_lists:
        samples = redistribute_codes(code_list)
        my_samples.append(samples)

    from scipy.io.wavfile import write
    if len(prompts) != len(my_samples):
        raise Exception("Number of prompts and samples do not match")
    else:
        for i in range(len(my_samples)):
            print(prompts[i])
            samples = my_samples[i]
            audio_array = samples.detach().squeeze().to("cpu").numpy()
            filename = f"output_{i}.wav"
            write(filename, 24000, audio_array)
            print(f"Saved audio to {filename}")

    # Clean up to save RAM
    del my_samples, samples

prompts = [
    "A quick brown fox jumped over the lazy dog!",
]

chosen_voice = None  # None for single-speaker

# Generate and save WAV files for the prompts with the finetuned model
gen_wav(prompts, chosen_voice)