sahuraj0909

Orpheus TTS API server

May 1st, 2025
# =============================================================================
# Standard Library Imports
# =============================================================================
import json
import logging
import os
import re
import time
import traceback
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union

# =============================================================================
# Third-Party Library Imports
# =============================================================================
import numpy as np
import requests
import torch
from torch import nn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn  # For running the server

# =============================================================================
# Logging & Library Imports with Error Handling
# =============================================================================
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)

try:
    from snac import SNAC
    logger.info("SNAC imported.")
except ImportError:
    logger.error("SNAC not found. Install with: pip install git+https://github.com/hubertsiuzdak/snac.git")
    SNAC = None  # Set to None if import fails

# =============================================================================
# Configuration Loading
# =============================================================================
# --- API Endpoints & Model Names ---
# UPSTREAM_TTS_API_ENDPOINT points at the upstream completion service (e.g. LM Studio or a llama.cpp server)
# UPSTREAM_TTS_API_ENDPOINT = os.getenv("UPSTREAM_TTS_API_ENDPOINT", "http://127.0.0.1:1234/v1/completions")
UPSTREAM_TTS_API_ENDPOINT = os.getenv("UPSTREAM_TTS_API_ENDPOINT", "http://127.0.0.1:8080/v1/completions")
# TTS_MODEL = os.getenv("TTS_MODEL", "isaiahbjork/orpheus-3b-0.1-ft")
TTS_MODEL = os.getenv("TTS_MODEL", "lex-au/Orpheus-3b-FT-Q2_K.gguf")

# --- Prompts ---
TTS_PROMPT_FORMAT = "<|audio|>{voice}: {text}<|eot_id|>"
TTS_PROMPT_STOP_TOKENS = ["<|eot_id|>", "<|audio|>"]

logger.info(f"Upstream TTS Endpoint: {UPSTREAM_TTS_API_ENDPOINT}, TTS Model: {TTS_MODEL}")

# =============================================================================
# Constants
# =============================================================================
# --- TTS Default Parameters ---
DEFAULT_TTS_TEMP = 0.8
DEFAULT_TTS_TOP_P = 0.9
DEFAULT_TTS_REP_PENALTY = 1.1

# --- Orpheus/SNAC Specific Constants ---
ORPHEUS_MIN_ID = 10
ORPHEUS_TOKENS_PER_LAYER = 4096
ORPHEUS_N_LAYERS = 7
ORPHEUS_MAX_ID = ORPHEUS_MIN_ID + (ORPHEUS_N_LAYERS * ORPHEUS_TOKENS_PER_LAYER)
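# i.e. 10 + 7 * 4096 = 28682; IDs in [ORPHEUS_MIN_ID, ORPHEUS_MAX_ID) are treated as audio codes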

# --- Audio Processing & Misc ---
TARGET_SAMPLE_RATE = 24000  # Crucial for client playback!
AUDIO_DTYPE = np.float32  # SNAC outputs float32

# --- Streaming TTS Constants ---
DEFAULT_TTS_STREAM_MIN_GROUPS = 40
DEFAULT_TTS_STREAM_SILENCE_MS = 5

# --- API Communication ---
API_TIMEOUT_SECONDS = 180
STREAM_TIMEOUT_SECONDS = 300
STREAM_HEADERS = {"Content-Type": "application/json", "Accept": "text/event-stream"}
SSE_DATA_PREFIX = "data:"
SSE_DONE_MARKER = "[DONE]"
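# Illustrative upstream SSE lines (shape assumed for a llama.cpp-style /v1/completions stream):
#   data: {"choices":[{"text":"<custom_token_1234><custom_token_5678>","finish_reason":null}]}
#   data: [DONE]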

# --- Voice Constants ---
ALL_VOICES = ["tara", "jess", "leo", "leah", "dan", "mia", "zac", "zoe"]
DEFAULT_TTS_VOICE = ALL_VOICES[0]

# =============================================================================
# Device Setup
# =============================================================================
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"TTS Device: '{tts_device}'")

# =============================================================================
# Utility Functions (Copied from original script)
# =============================================================================
def parse_gguf_codes(response_text: str) -> List[int]:
    """Parse Orpheus <custom_token_ID> tokens from text."""
    try:
        logger.debug(f"PARSING FOR CODES IN: {response_text[:200]}...")

        codes = [
            int(m) for m in re.findall(r"<custom_token_(\d+)>", response_text)
            if ORPHEUS_MIN_ID <= int(m) < ORPHEUS_MAX_ID
        ]

        if codes:
            logger.debug(f"FOUND {len(codes)} CODES: first few = {codes[:10]}...")
        else:
            logger.debug("NO CODES FOUND using pattern '<custom_token_(\\d+)>'")

        return codes
    except Exception as e:
        logger.error(f"GGUF parse error: {e}")
        return []

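# parse_gguf_codes example (illustrative): a streamed chunk such as
#   "<custom_token_1234><custom_token_5678><custom_token_9>"
# yields [1234, 5678]; the trailing 9 is dropped because it falls outside
# [ORPHEUS_MIN_ID, ORPHEUS_MAX_ID).
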
def redistribute_codes(codes: List[int], model: nn.Module) -> Optional[np.ndarray]:
    """Convert absolute Orpheus token IDs to SNAC input tensors and decode audio."""
    if not codes or model is None:
        return None

    try:
        dev = next(model.parameters()).device
        layers: List[List[int]] = [[], [], []]
        groups = len(codes) // ORPHEUS_N_LAYERS

        if groups == 0:
            return None

        valid = 0
        for i in range(groups):
            idx = i * ORPHEUS_N_LAYERS
            group = codes[idx : idx + ORPHEUS_N_LAYERS]
            processed: List[Optional[int]] = [None] * ORPHEUS_N_LAYERS
            ok = True

            for j, t_id in enumerate(group):
                if not (ORPHEUS_MIN_ID <= t_id < ORPHEUS_MAX_ID):
                    ok = False; break
                layer_idx = (t_id - ORPHEUS_MIN_ID) // ORPHEUS_TOKENS_PER_LAYER
                code_idx = (t_id - ORPHEUS_MIN_ID) % ORPHEUS_TOKENS_PER_LAYER
                if layer_idx != j:
                    ok = False; break
                processed[j] = code_idx

            if ok:
                try:
                    if any(c is None for c in processed): continue
                    pg: List[int] = processed
                    layers[0].append(pg[0]); layers[1].append(pg[1]); layers[2].append(pg[2])
                    layers[2].append(pg[3]); layers[1].append(pg[4]); layers[2].append(pg[5])
                    layers[2].append(pg[6]); valid += 1
                except (IndexError, TypeError) as map_e:
                    logger.error(f"Code map error in group {i}: {map_e}"); continue

        if valid == 0:
            logger.warning("No valid code groups found after processing.")
            return None
        if not all(layers):
            logger.error("SNAC layers empty after processing valid groups.")
            return None

        tensors = [torch.tensor(lc, device=dev, dtype=torch.long).unsqueeze(0) for lc in layers]
        with torch.no_grad():
            audio = model.decode(tensors)

        # Ensure output is float32 numpy array
        return audio.detach().squeeze().cpu().to(torch.float32).numpy()

    except Exception as e:
        logger.exception("SNAC decode error during tensor creation or decoding.")
        return None

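# redistribute_codes layout (as implemented above): each group of 7 Orpheus IDs is split
# 1 / 2 / 4 across SNAC's three codebook layers: pg[0] -> layer 0; pg[1], pg[4] -> layer 1;
# pg[2], pg[3], pg[5], pg[6] -> layer 2 (coarse to fine temporal resolution).
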
def apply_fade(audio_chunk: np.ndarray, sample_rate: int, fade_ms: int = 3) -> np.ndarray:
    """Apply a short linear fade-in and fade-out to an audio chunk."""
    num_fade_samples = int(sample_rate * (fade_ms / 1000.0))

    if num_fade_samples <= 0 or audio_chunk.size < 3 * num_fade_samples:
        return audio_chunk

    fade_in = np.linspace(0., 1., num_fade_samples, dtype=audio_chunk.dtype)
    fade_out = np.linspace(1., 0., num_fade_samples, dtype=audio_chunk.dtype)

    chunk_copy = audio_chunk.copy()  # Ensure we work on a copy
    chunk_copy[:num_fade_samples] *= fade_in
    chunk_copy[-num_fade_samples:] *= fade_out

    return chunk_copy

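# apply_fade note: at TARGET_SAMPLE_RATE (24 kHz) the default 3 ms fade is
# int(24000 * 0.003) = 72 samples per edge; chunks shorter than 3 * 72 = 216 samples
# are returned unchanged.
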
# =============================================================================
# Model Loading
# =============================================================================
snac_model: Optional[SNAC] = None
if SNAC:
    logger.info("--- Loading Local SNAC Model ---")
    try:
        import warnings
        warnings.filterwarnings("ignore", category=FutureWarning, module="snac.snac")
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        if snac_model:
            snac_model = snac_model.to(tts_device).eval()
            logger.info(f"SNAC loaded to '{tts_device}'.")
            # Optional Warmup
            try:
                logger.info("Attempting SNAC warm-up...")
                dummy_tokens = [
                    min(ORPHEUS_MIN_ID + i * ORPHEUS_TOKENS_PER_LAYER + 100, ORPHEUS_MAX_ID - 1)
                    for i in range(ORPHEUS_N_LAYERS)
                ] * 10  # Small warmup sequence
                warmup_audio = redistribute_codes(dummy_tokens, snac_model)
                if warmup_audio is not None and warmup_audio.size > 0:
                    logger.info(f"SNAC warm-up OK (produced {warmup_audio.size} samples).")
                else:
                    logger.warning("SNAC warm-up call ran but produced no audio.")
            except Exception as wu_e:
                logger.exception("SNAC warm-up call failed.")
        else:
            logger.error("SNAC.from_pretrained returned None. Model not loaded.")
            snac_model = None
    except Exception as e:
        logger.exception("Fatal error loading SNAC.")
        snac_model = None
else:
    logger.critical("SNAC library not found. TTS server cannot function.")
    # Optionally exit here if SNAC is mandatory
    # exit(1)

if not snac_model:
    logger.critical("SNAC model failed to load. TTS endpoint will return errors.")

# =============================================================================
# TTS Pipeline Function (Modified for FastAPI StreamingResponse)
# =============================================================================
async def generate_speech_stream_bytes(
    text: str,
    voice: str,
    tts_temperature: float,
    tts_top_p: float,
    tts_repetition_penalty: float,
    buffer_groups_param: int,
    padding_ms_param: int,
) -> AsyncGenerator[bytes, None]:
    """Generate raw audio chunk bytes via the upstream TTS streaming API and local SNAC decoding."""

    if not snac_model:
        logger.error("generate_speech_stream_bytes called but snac_model is not loaded.")
        # The endpoint already rejects requests with 503 when SNAC is missing;
        # yield nothing here so the stream ends cleanly if this path is ever reached.
        yield b''
        return

    if not text.strip():
        logger.warning("generate_speech_stream_bytes called with empty text.")
        yield b''
        return

    min_codes_required = buffer_groups_param * ORPHEUS_N_LAYERS
    logger.debug(f"Stream processing: buffer={buffer_groups_param} groups ({min_codes_required} codes), padding={padding_ms_param} ms")

    silence_samples: int = 0
    if padding_ms_param > 0:
        silence_samples = int(TARGET_SAMPLE_RATE * (padding_ms_param / 1000.0))
        logger.debug(f"Calculated silence samples per side: {silence_samples}")

    payload = {
        "model": TTS_MODEL,
        "prompt": TTS_PROMPT_FORMAT.format(voice=voice, text=text),
        "temperature": tts_temperature,
        "top_p": tts_top_p,
        "repeat_penalty": tts_repetition_penalty,
        "n_predict": -1,  # Predict until stop token or model limit
        "stop": TTS_PROMPT_STOP_TOKENS,
        "stream": True
    }

    accumulated_codes: List[int] = []
    stream_start_time = time.time()
    any_audio_yielded = False

    try:
        logger.info(">>> TTS API: Initiating stream request to upstream...")
        # Note: using the synchronous requests library inside an async generator blocks the event
        # loop while waiting on the upstream; for higher concurrency, consider an async HTTP client
        # such as httpx.
        with requests.post(
            UPSTREAM_TTS_API_ENDPOINT, json=payload, headers=STREAM_HEADERS, stream=True, timeout=STREAM_TIMEOUT_SECONDS
        ) as response:
            response.raise_for_status()
            logger.info(f"--- TTS API: Upstream stream connected after {time.time() - stream_start_time:.3f}s. Receiving codes...")

            for line in response.iter_lines():
                if not line: continue
                try:
                    decoded_line = line.decode(response.encoding or 'utf-8')
                    logger.debug(f"RAW LINE: {decoded_line[:200]}...")  # Log the raw SSE line
                except UnicodeDecodeError:
                    logger.warning(f"Skipping undecodable line: {line[:50]}...")
                    continue

                if decoded_line.startswith(SSE_DATA_PREFIX):
                    json_str = decoded_line[len(SSE_DATA_PREFIX):].strip()
                    if json_str == SSE_DONE_MARKER:
                        logger.debug("Received TTS SSE_DONE_MARKER from upstream.")
                        break
                    if not json_str: continue

                    try:
                        data = json.loads(json_str)
                        logger.debug(f"PARSED JSON STRUCTURE: {json.dumps(data, indent=2)[:500]}...")

                        # Extract the text delta from an OpenAI-compatible streaming chunk
                        # (either choices[0].delta.content or choices[0].text).
                        chunk_text = ""
                        try:
                            choices = data.get("choices", [])
                            if choices:
                                first_choice = choices[0]
                                chunk_text = (
                                    first_choice.get("delta", {}).get("content")
                                    or first_choice.get("text")
                                    or ""
                                )
                        except (AttributeError, IndexError, TypeError):
                            chunk_text = ""

                        if chunk_text:
                            new_codes = parse_gguf_codes(chunk_text)
                            if new_codes:
                                accumulated_codes.extend(new_codes)
                                # Process codes if enough are accumulated
                                if len(accumulated_codes) >= min_codes_required:
                                    num_groups_to_decode = len(accumulated_codes) // ORPHEUS_N_LAYERS
                                    codes_to_decode = accumulated_codes[:num_groups_to_decode * ORPHEUS_N_LAYERS]
                                    accumulated_codes = accumulated_codes[num_groups_to_decode * ORPHEUS_N_LAYERS:]

                                    snac_start_time = time.time()
                                    audio_chunk = redistribute_codes(codes_to_decode, snac_model)
                                    snac_end_time = time.time()

                                    if audio_chunk is not None and audio_chunk.size > 0:
                                        logger.debug(f"--- SNAC: Decoded chunk ({len(codes_to_decode)} codes -> {audio_chunk.size} samples) in {snac_end_time - snac_start_time:.3f}s.")
                                        faded_chunk = apply_fade(audio_chunk, TARGET_SAMPLE_RATE, fade_ms=3)

                                        # Add padding if needed
                                        if silence_samples > 0:
                                            silence = np.zeros(silence_samples, dtype=AUDIO_DTYPE)
                                            final_chunk = np.concatenate((silence, faded_chunk, silence))
                                        else:
                                            final_chunk = faded_chunk

                                        yield final_chunk.tobytes()
                                        any_audio_yielded = True
                                    else:
                                        logger.warning(f"--- SNAC: Failed to decode chunk ({len(codes_to_decode)} codes) in {snac_end_time - snac_start_time:.3f}s.")

                        # Check for stop conditions from upstream API
                        stop_reason = None
                        is_stopped = False
                        if "choices" in data and data["choices"]:
                            stop_reason = data["choices"][0].get("finish_reason")
                        # Check common stop flags from llama.cpp server
                        if stop_reason or data.get("stop") or data.get("stopped_eos") or data.get("stopped_limit"):
                            is_stopped = True
                            logger.debug(f"Upstream TTS Stream stop condition met: reason='{stop_reason}', data flags: stop={data.get('stop')}, eos={data.get('stopped_eos')}, limit={data.get('stopped_limit')}")
                        if is_stopped:
                            break

                    except json.JSONDecodeError:
                        logger.warning(f"Skipping invalid JSON in upstream TTS stream: {json_str[:100]}...")
                        continue
                    except Exception as e:
                        logger.exception(f"Error processing upstream TTS stream chunk: {json_str[:100]}...")
                        continue  # Or break depending on desired behavior

            # Process any remaining codes after the loop
            if len(accumulated_codes) >= ORPHEUS_N_LAYERS:
                logger.debug(f"Processing final {len(accumulated_codes)} codes after stream end.")
                num_groups = len(accumulated_codes) // ORPHEUS_N_LAYERS
                codes_to_decode = accumulated_codes[:num_groups * ORPHEUS_N_LAYERS]

                snac_start_time = time.time()
                audio_chunk = redistribute_codes(codes_to_decode, snac_model)
                snac_end_time = time.time()

                if audio_chunk is not None and audio_chunk.size > 0:
                    logger.debug(f"--- SNAC: Decoded final chunk ({len(codes_to_decode)} codes -> {audio_chunk.size} samples) in {snac_end_time - snac_start_time:.3f}s.")
                    faded_chunk = apply_fade(audio_chunk, TARGET_SAMPLE_RATE, fade_ms=3)
                    if silence_samples > 0:
                        silence = np.zeros(silence_samples, dtype=AUDIO_DTYPE)
                        final_chunk = np.concatenate((silence, faded_chunk, silence))
                    else:
                        final_chunk = faded_chunk
                    yield final_chunk.tobytes()
                    any_audio_yielded = True
                else:
                    logger.warning(f"--- SNAC: Failed to decode final chunk ({len(codes_to_decode)} codes) in {snac_end_time - snac_start_time:.3f}s.")
            elif accumulated_codes:
                logger.debug(f"Discarding final {len(accumulated_codes)} codes (less than {ORPHEUS_N_LAYERS}).")

    except requests.exceptions.RequestException as e:
        logger.exception(f"<<< TTS API: Upstream RequestException after {time.time() - stream_start_time:.3f}s.")
        # Optionally yield an error indicator or raise within the generator context if FastAPI handles it
        # yield b'ERROR: Upstream connection failed' # Example of sending text error
    except Exception as e:
        logger.exception(f"<<< TTS API: Unexpected error during upstream streaming after {time.time() - stream_start_time:.3f}s.")
        # yield b'ERROR: Internal server error'
    finally:
        if not any_audio_yielded:
            logger.warning("<<< TTS API: Stream finished but NO audio was generated/yielded.")
        logger.info(f"<<< TTS API: Upstream stream processing finished after {time.time() - stream_start_time:.3f}s.")


# =============================================================================
# FastAPI Application
# =============================================================================
app = FastAPI(title="Streaming Orpheus TTS Server", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # Your React app's URL (adjust port if needed)
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods, or restrict to e.g. ["GET", "POST"]
    allow_headers=["*"],  # Allow all headers
)

class TTSRequest(BaseModel):
    text: str
    voice: str = DEFAULT_TTS_VOICE
    temperature: float = DEFAULT_TTS_TEMP
    top_p: float = DEFAULT_TTS_TOP_P
    repetition_penalty: float = DEFAULT_TTS_REP_PENALTY
    buffer_groups: int = DEFAULT_TTS_STREAM_MIN_GROUPS
    padding_ms: int = DEFAULT_TTS_STREAM_SILENCE_MS

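# Example request body (illustrative):
#   {"text": "Hello there!", "voice": "tara", "buffer_groups": 40, "padding_ms": 5}
# Omitted fields fall back to the defaults above.
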
@app.on_event("startup")
async def startup_event():
    logger.info("Server starting up.")
    if not snac_model:
        logger.warning("SNAC model is not loaded. /stream-tts endpoint may fail.")

@app.on_event("shutdown")
def shutdown_event():
    logger.info("Server shutting down.")

@app.get("/health")
async def health_check():
    """Basic health check endpoint."""
    if snac_model:
        return {"status": "ok", "snac_loaded": True, "device": tts_device}
    else:
        return {"status": "warning", "snac_loaded": False, "message": "SNAC model failed to load."}

@app.post("/stream-tts/")
async def stream_tts_endpoint(request: TTSRequest):
    """
    Generate streaming audio for the given text and parameters.
    Streams raw audio bytes (float32, mono, 24 kHz).
    """
    if not snac_model:
        raise HTTPException(status_code=503, detail="TTS service unavailable: SNAC model not loaded.")
    if request.voice not in ALL_VOICES:
        raise HTTPException(status_code=400, detail=f"Invalid voice '{request.voice}'. Available: {ALL_VOICES}")

    logger.info(f"Received TTS request: voice='{request.voice}', text='{request.text[:50]}...', buffer={request.buffer_groups}, padding={request.padding_ms}")

    # Clamp tunable parameters to reasonable bounds
    request.buffer_groups = max(5, min(request.buffer_groups, 100))
    request.padding_ms = max(0, min(request.padding_ms, 500))

    # Create the generator
    audio_generator = generate_speech_stream_bytes(
        text=request.text,
        voice=request.voice,
        tts_temperature=request.temperature,
        tts_top_p=request.top_p,
        tts_repetition_penalty=request.repetition_penalty,
        buffer_groups_param=request.buffer_groups,
        padding_ms_param=request.padding_ms,
    )

    # Return a StreamingResponse.
    # The client needs to know the sample rate and dtype beforehand.
    return StreamingResponse(
        audio_generator,
        media_type="application/octet-stream"  # Sending raw bytes
        # Headers can be used to send metadata if needed, e.g.,
        # headers={"X-Sample-Rate": str(TARGET_SAMPLE_RATE), "X-Audio-Dtype": "float32"}
    )

# =============================================================================
# Server Entry Point
# =============================================================================
if __name__ == "__main__":
    server_port = int(os.getenv("TTS_SERVER_PORT", 8001))
    logger.info(f"Starting FastAPI server on 0.0.0.0:{server_port}")
    # Note: uvicorn's auto-reload only works when the app is passed as an import string,
    # e.g. uvicorn.run("module_name:app", reload=True); passing the app object runs without reload.
    uvicorn.run(app, host="0.0.0.0", port=server_port, log_level="info")
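
# -----------------------------------------------------------------------------
# Example client (illustrative sketch, not part of the server): consumes the raw
# float32 / mono / 24 kHz byte stream and writes it to a WAV file. The URL, port,
# and output filename below are assumptions based on the defaults defined above.
# -----------------------------------------------------------------------------
# import numpy as np
# import requests
# import soundfile as sf  # pip install soundfile
#
# resp = requests.post(
#     "http://localhost:8001/stream-tts/",
#     json={"text": "Hello from Orpheus!", "voice": "tara"},
#     stream=True,
#     timeout=300,
# )
# resp.raise_for_status()
# raw = b"".join(chunk for chunk in resp.iter_content(chunk_size=8192) if chunk)
# audio = np.frombuffer(raw, dtype=np.float32)
# sf.write("output.wav", audio, 24000)  # TARGET_SAMPLE_RATE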
Tags: ai