Guest User

Orpheus via llama.cpp server

a guest
Mar 27th, 2025
301
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.90 KB | Source Code | 0 0
  1. # Based on https://github.com/freddyaboulton/orpheus-cpp/
  2.  
  3. import argparse
  4. import asyncio
  5. import json
  6. import platform
  7. import requests
  8. import soundfile
  9. import threading
  10. import winsound
  11. from typing import (
  12.     AsyncGenerator,
  13.     Generator,
  14.     Iterator,
  15.     Literal,
  16.     NotRequired,
  17.     TypedDict,
  18.     cast,
  19. )
  20.  
  21. import numpy as np
  22. import onnxruntime
  23. from numpy.typing import NDArray
  24.  
  25.  
class TTSOptions(TypedDict):
    """Optional generation settings for OrpheusCpp TTS calls.

    Every key is optional (``NotRequired``); consumers apply the defaults
    documented on each field via ``options.get(key, default)``.
    """

    max_tokens: NotRequired[int]
    """Maximum number of tokens to generate. Default: 2048"""
    temperature: NotRequired[float]
    """Temperature for top-p sampling. Default: 0.8"""
    top_p: NotRequired[float]
    """Top-p sampling. Default: 0.95"""
    top_k: NotRequired[int]
    """Top-k sampling. Default: 40"""
    min_p: NotRequired[float]
    """Minimum probability for top-p sampling. Default: 0.05"""
    pre_buffer_size: NotRequired[float]
    """Seconds of audio to generate before yielding the first chunk. Smoother audio streaming at the cost of higher time to wait for the first chunk."""
    voice_id: NotRequired[
        Literal["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
    ]
    """The voice to use for the TTS. Default: "tara"."""
  43.  
  44.  
# Prefix of the audio tokens streamed back by the Orpheus model,
# e.g. "<custom_token_1234>"; the numeric suffix encodes a SNAC codebook id.
CUSTOM_TOKEN_PREFIX = "<custom_token_"
  46.  
  47.  
  48. class OrpheusCpp:
  49.     def __init__(self, verbose: bool = True):
  50.         import importlib.util
  51.  
  52.         snac_model_path = "snac_decoder_model.onnx"
  53.  
  54.         # Load SNAC model with optimizations
  55.         self._snac_session = onnxruntime.InferenceSession(
  56.             snac_model_path,
  57.             providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
  58.         )
  59.  
  60.     def _token_to_id(self, token_text: str, index: int) -> int | None:
  61.         token_string = token_text.strip()
  62.  
  63.         # Find the last token in the string
  64.         last_token_start = token_string.rfind(CUSTOM_TOKEN_PREFIX)
  65.  
  66.         if last_token_start == -1:
  67.             return None
  68.  
  69.         # Extract the last token
  70.         last_token = token_string[last_token_start:]
  71.  
  72.         # Process the last token
  73.         if last_token.startswith(CUSTOM_TOKEN_PREFIX) and last_token.endswith(">"):
  74.             try:
  75.                 number_str = last_token[14:-1]
  76.                 token_id = int(number_str) - 10 - ((index % 7) * 4096)
  77.                 return token_id
  78.             except ValueError:
  79.                 return None
  80.         else:
  81.             return None
  82.  
  83.     def _decode(
  84.         self, token_gen: Generator[str, None, None]
  85.     ) -> Generator[np.ndarray, None, None]:
  86.         """Asynchronous token decoder that converts token stream to audio stream."""
  87.         buffer = []
  88.         count = 0
  89.         for token_text in token_gen:
  90.             token = self._token_to_id(token_text, count)
  91.             if token is not None and token > 0:
  92.                 buffer.append(token)
  93.                 count += 1
  94.  
  95.                 # Convert to audio when we have enough tokens
  96.                 if count % 7 == 0 and count > 27:
  97.                     buffer_to_proc = buffer[-28:]
  98.                     audio_samples = self._convert_to_audio(buffer_to_proc)
  99.                     if audio_samples is not None:
  100.                         yield audio_samples
  101.  
  102.     def _convert_to_audio(self, multiframe: list[int]) -> np.ndarray | None:
  103.         if len(multiframe) < 28:  # Ensure we have enough tokens
  104.             return None
  105.  
  106.         num_frames = len(multiframe) // 7
  107.         frame = multiframe[: num_frames * 7]
  108.  
  109.         # Initialize empty numpy arrays instead of torch tensors
  110.         codes_0 = np.array([], dtype=np.int32)
  111.         codes_1 = np.array([], dtype=np.int32)
  112.         codes_2 = np.array([], dtype=np.int32)
  113.  
  114.         for j in range(num_frames):
  115.             i = 7 * j
  116.             # Append values to numpy arrays
  117.             codes_0 = np.append(codes_0, frame[i])
  118.  
  119.             codes_1 = np.append(codes_1, [frame[i + 1], frame[i + 4]])
  120.  
  121.             codes_2 = np.append(
  122.                 codes_2, [frame[i + 2], frame[i + 3], frame[i + 5], frame[i + 6]]
  123.             )
  124.  
  125.         # Reshape arrays to match the expected input format (add batch dimension)
  126.         codes_0 = np.expand_dims(codes_0, axis=0)
  127.         codes_1 = np.expand_dims(codes_1, axis=0)
  128.         codes_2 = np.expand_dims(codes_2, axis=0)
  129.  
  130.         # Check that all tokens are between 0 and 4096
  131.         if (
  132.             np.any(codes_0 < 0)
  133.             or np.any(codes_0 > 4096)
  134.             or np.any(codes_1 < 0)
  135.             or np.any(codes_1 > 4096)
  136.             or np.any(codes_2 < 0)
  137.             or np.any(codes_2 > 4096)
  138.         ):
  139.             return None
  140.  
  141.         # Create input dictionary for ONNX session
  142.  
  143.         snac_input_names = [x.name for x in self._snac_session.get_inputs()]
  144.  
  145.         input_dict = dict(zip(snac_input_names, [codes_0, codes_1, codes_2]))
  146.  
  147.         # Run inference
  148.         audio_hat = self._snac_session.run(None, input_dict)[0]
  149.  
  150.         # Process output
  151.         audio_np = audio_hat[:, :, 2048:4096]
  152.         audio_int16 = (audio_np * 32767).astype(np.int16)
  153.         audio_bytes = audio_int16.tobytes()
  154.         return audio_bytes
  155.  
  156.     def tts(
  157.         self, text: str, options: TTSOptions | None = None
  158.     ) -> tuple[int, NDArray[np.int16]]:
  159.         buffer = []
  160.         for _, array in self.stream_tts_sync(text, options):
  161.             buffer.append(array)
  162.         return (24_000, np.concatenate(buffer, axis=1))
  163.  
  164.     async def stream_tts(
  165.         self, text: str, options: TTSOptions | None = None
  166.     ) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]:
  167.         queue = asyncio.Queue()
  168.         finished = asyncio.Event()
  169.  
  170.         def strem_to_queue(text, options, queue, finished):
  171.             for chunk in self.stream_tts_sync(text, options):
  172.                 queue.put_nowait(chunk)
  173.             finished.set()
  174.  
  175.         thread = threading.Thread(
  176.             target=strem_to_queue, args=(text, options, queue, finished)
  177.         )
  178.         thread.start()
  179.         while not finished.is_set():
  180.             try:
  181.                 yield await asyncio.wait_for(queue.get(), 0.1)
  182.             except (asyncio.TimeoutError, TimeoutError):
  183.                 pass
  184.         while not queue.empty():
  185.             chunk = queue.get_nowait()
  186.             yield chunk
  187.  
  188.     def _token_gen(
  189.         self, text: str, options: TTSOptions | None = None
  190.     ) -> Generator[str, None, None]:
  191.  
  192.         options = options or TTSOptions()
  193.         voice_id = options.get("voice_id", "tara")
  194.         text = f"<|audio|>{voice_id}: {text}<|eot_id|><custom_token_4>"
  195.         completion_url = "http://localhost:8080/completion"
  196.         data = {
  197.             "stream": True,
  198.             "prompt": text,
  199.             "max_tokens": options.get("max_tokens", 2_048),
  200.             "temperature": options.get("temperature", 0.8),
  201.             "top_p": options.get("top_p", 0.95),
  202.             "top_k": options.get("top_k", 40),
  203.             "min_p": options.get("min_p", 0.05),
  204.         }
  205.         response = requests.post(completion_url, json=data, stream=True)
  206.         for line in response.iter_lines():
  207.             line = line.decode("utf-8")
  208.  
  209.             if line.startswith("data: ") and not line.endswith("[DONE]"):
  210.                 data = json.loads(line[len("data: "):])
  211.                 yield data["content"]
  212.  
  213.  
  214.     def stream_tts_sync(
  215.         self, text: str, options: TTSOptions | None = None
  216.     ) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
  217.         options = options or TTSOptions()
  218.         token_gen = self._token_gen(text, options)
  219.         pre_buffer = np.array([], dtype=np.int16).reshape(1, 0)
  220.         pre_buffer_size = 24_000 * options.get("pre_buffer_size", 1.5)
  221.         started_playback = False
  222.         for audio_bytes in self._decode(token_gen):
  223.             audio_array = np.frombuffer(audio_bytes, dtype=np.int16).reshape(1, -1)
  224.             if not started_playback:
  225.                 pre_buffer = np.concatenate([pre_buffer, audio_array], axis=1)
  226.                 if pre_buffer.shape[1] >= pre_buffer_size:
  227.                     started_playback = True
  228.                     yield (24_000, pre_buffer)
  229.             else:
  230.                 yield (24_000, audio_array)
  231.         if not started_playback:
  232.             yield (24_000, pre_buffer)
  233.  
  234. def main():
  235.     parser = argparse.ArgumentParser(description="Text-to-Speech with OrpheusCpp")
  236.     parser.add_argument("--text", type=str, help="The text to convert to speech. You can use these tags: <giggle>, <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>")
  237.     parser.add_argument("--voice", type=str, choices=["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"], default="tara", help="The voice to use for the TTS")
  238.     args = parser.parse_args()
  239.  
  240.     orpheus = OrpheusCpp()
  241.     sample_rate, samples = orpheus.tts(args.text.strip(), options={"voice_id": args.voice, "temperature": 0.3})
  242.     soundfile.write("output.wav", samples.squeeze(), sample_rate)
  243.     winsound.PlaySound("output.wav", winsound.SND_FILENAME)
  244.  
  245. if __name__ == "__main__":
  246.     main()
Advertisement
Add Comment
Please, Sign In to add comment