Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# =============================================================================
# Standard Library Imports
# =============================================================================
import asyncio
import json
import logging
import os
import re
import time
import traceback
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union

# =============================================================================
# Third-Party Library Imports
# =============================================================================
import numpy as np
import requests
import torch
from torch import nn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn # For running the server
# =============================================================================
# Library Imports with Error Handling & Conditional Imports
# =============================================================================
# Logging is configured before the guarded import below so that the outcome of
# that import is reported with the configured format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)
# SNAC is imported defensively: when the package is missing the module still
# imports, an error is logged, and the name SNAC is bound to None so later code
# can test for its presence.
try:
    from snac import SNAC
    logger.info("SNAC imported.")
except ImportError:
    logger.error("SNAC not found. pip install git+https://github.com/hubertsiuzdak/snac.git")
    SNAC = None # Set to None if import fails
# =============================================================================
# Configuration Loading
# =============================================================================
# --- API Endpoints & Model Names ---
# URL of the upstream completions service that produces the Orpheus tokens
# (for example LM Studio). Override with the UPSTREAM_TTS_API_ENDPOINT
# environment variable.
# Alternative default kept for reference: "http://127.0.0.1:1234/v1/completions"
UPSTREAM_TTS_API_ENDPOINT = os.getenv("UPSTREAM_TTS_API_ENDPOINT", "http://127.0.0.1:8080/v1/completions")
# Model name sent as "model" in the upstream request payload. Override with the
# TTS_MODEL environment variable.
# Alternative default kept for reference: "isaiahbjork/orpheus-3b-0.1-ft"
TTS_MODEL = os.getenv("TTS_MODEL", "lex-au/Orpheus-3b-FT-Q2_K.gguf")
# --- Prompts ---
# Prompt template; {voice} and {text} are substituted for every request.
TTS_PROMPT_FORMAT = "<|audio|>{voice}: {text}<|eot_id|>"
# Sent as the "stop" list in the upstream request payload.
TTS_PROMPT_STOP_TOKENS = ["<|eot_id|>", "<|audio|>"]
logger.info(f"Upstream TTS Endpoint: {UPSTREAM_TTS_API_ENDPOINT}, TTS Model: {TTS_MODEL}")
# =============================================================================
# Constants
# =============================================================================
# --- TTS Default Parameters ---
# Defaults for the sampling fields of TTSRequest (temperature, top_p,
# repetition_penalty).
DEFAULT_TTS_TEMP = 0.8
DEFAULT_TTS_TOP_P = 0.9
DEFAULT_TTS_REP_PENALTY = 1.1
# --- Orpheus/SNAC Specific Constants ---
# Lowest <custom_token_N> ID that is accepted as an audio code.
ORPHEUS_MIN_ID = 10
# Width of each positional ID band: the code at position j of a group must lie
# in [MIN_ID + j * TOKENS_PER_LAYER, MIN_ID + (j + 1) * TOKENS_PER_LAYER).
ORPHEUS_TOKENS_PER_LAYER = 4096
# Number of codes per group. NOTE: despite the name this is not the number of
# SNAC code lists; redistribute_codes spreads each group of 7 over 3 lists.
ORPHEUS_N_LAYERS = 7
# Exclusive upper bound for accepted IDs (compared with "<").
ORPHEUS_MAX_ID = ORPHEUS_MIN_ID + (ORPHEUS_N_LAYERS * ORPHEUS_TOKENS_PER_LAYER)
# --- Audio Processing & Misc ---
# Sample rate used to size fades and silence padding.
TARGET_SAMPLE_RATE = 24000 # Crucial for client playback!
# dtype of the silence padding concatenated around each decoded chunk.
AUDIO_DTYPE = np.float32 # SNAC outputs float32
# --- Streaming TTS Constants ---
# Default TTSRequest.buffer_groups: groups accumulated before a chunk is decoded.
DEFAULT_TTS_STREAM_MIN_GROUPS = 40
# Default TTSRequest.padding_ms: silence added on each side of every chunk.
DEFAULT_TTS_STREAM_SILENCE_MS = 5
# --- API Communication ---
# NOTE(review): API_TIMEOUT_SECONDS is not referenced in the visible part of
# this file — confirm whether it is still needed.
API_TIMEOUT_SECONDS = 180
# Passed as the timeout of the streaming POST to the upstream service.
STREAM_TIMEOUT_SECONDS = 300
STREAM_HEADERS = {"Content-Type": "application/json", "Accept": "text/event-stream"}
# Prefix of the upstream stream lines that carry a JSON payload.
SSE_DATA_PREFIX = "data:"
# Payload value that marks the end of the upstream stream.
SSE_DONE_MARKER = "[DONE]"
# --- Voice Constants ---
# Voices accepted by the /stream-tts/ endpoint; the first one is the default.
ALL_VOICES = ["tara", "jess", "leo", "leah", "dan", "mia", "zac", "zoe"]
DEFAULT_TTS_VOICE = ALL_VOICES[0]
# =============================================================================
# Device Setup
# =============================================================================
# Device the SNAC model is moved to: CUDA when torch reports it as available,
# otherwise the CPU.
tts_device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"TTS Device: '{tts_device}'")
# =============================================================================
# Utility Functions (Copied from original script)
# =============================================================================
def parse_gguf_codes(response_text: str) -> List[int]:
    """Extract the in-range Orpheus code IDs embedded in a piece of text.

    Every ``<custom_token_N>`` marker is located and ``N`` is kept, in order of
    appearance, when ``ORPHEUS_MIN_ID <= N < ORPHEUS_MAX_ID``.

    Args:
        response_text: Text that may contain custom-token markers.

    Returns:
        The accepted IDs, or an empty list when none are found or when any
        error occurs while parsing (the error is logged, never raised).
    """
    try:
        logger.debug(f"PARSING FOR CODES IN: {response_text[:200]}...")
        codes: List[int] = []
        for raw_id in re.findall(r"<custom_token_(\d+)>", response_text):
            token_id = int(raw_id)
            if token_id < ORPHEUS_MIN_ID or token_id >= ORPHEUS_MAX_ID:
                continue  # Outside the audio-code ID range.
            codes.append(token_id)
        if not codes:
            logger.debug(f"NO CODES FOUND using pattern '<custom_token_(\\d+)>'")
        else:
            logger.debug(f"FOUND {len(codes)} CODES: first few = {codes[:10]}...")
        return codes
    except Exception as e:
        logger.error(f"GGUF parse error: {e}")
        return []
def redistribute_codes(codes: List[int], model: nn.Module) -> Optional[np.ndarray]:
    """Convert absolute Orpheus token IDs to SNAC code tensors and decode audio.

    The IDs are consumed in groups of ``ORPHEUS_N_LAYERS`` (7); trailing IDs
    that do not fill a group are ignored. Inside a group the ID at position
    ``j`` must fall in band ``j``, i.e.
    ``(id - ORPHEUS_MIN_ID) // ORPHEUS_TOKENS_PER_LAYER == j``. Groups that
    break this rule are dropped as a whole. For every remaining group the
    in-band offsets are spread over three code lists:

    * list 0 receives position 0,
    * list 1 receives positions 1 and 4,
    * list 2 receives positions 2, 3, 5 and 6.

    Args:
        codes: Absolute token IDs in generation order.
        model: Loaded SNAC model; its ``decode`` method is called with the
            three ``(1, n)`` long tensors placed on the model's device.

    Returns:
        The decoded audio as a squeezed float32 NumPy array, or ``None`` when
        there is nothing to decode, no group is valid, or decoding fails
        (failures are logged, never raised).
    """
    if not codes or model is None:
        return None
    try:
        dev = next(model.parameters()).device
        groups = len(codes) // ORPHEUS_N_LAYERS
        if groups == 0:
            return None
        # One row per group, shifted so that band j starts at j * TOKENS_PER_LAYER.
        frames = np.asarray(codes[: groups * ORPHEUS_N_LAYERS], dtype=np.int64)
        frames = frames.reshape(groups, ORPHEUS_N_LAYERS) - ORPHEUS_MIN_ID
        # A row is valid only if every column sits in its own band. Floor
        # division also rejects IDs below ORPHEUS_MIN_ID (negative band) and
        # IDs at or above ORPHEUS_MAX_ID (band beyond the last column).
        in_expected_band = (frames // ORPHEUS_TOKENS_PER_LAYER) == np.arange(ORPHEUS_N_LAYERS)
        frames = frames[in_expected_band.all(axis=1)] % ORPHEUS_TOKENS_PER_LAYER
        if frames.shape[0] == 0:
            logger.warning("No valid code groups found after processing.")
            return None
        # Row-major flattening keeps the per-group order 1,4 and 2,3,5,6.
        layers = (
            frames[:, 0],
            frames[:, [1, 4]].reshape(-1),
            frames[:, [2, 3, 5, 6]].reshape(-1),
        )
        tensors = [
            torch.tensor(np.ascontiguousarray(lc), device=dev, dtype=torch.long).unsqueeze(0)
            for lc in layers
        ]
        with torch.no_grad():
            audio = model.decode(tensors)
        # Ensure output is float32 numpy array
        return audio.detach().squeeze().cpu().to(torch.float32).numpy()
    except Exception:
        logger.exception("SNAC decode error during tensor creation or decoding.")
        return None
def apply_fade(audio_chunk: np.ndarray, sample_rate: int, fade_ms: int = 3) -> np.ndarray:
    """Return the chunk with a short linear fade-in and fade-out applied.

    Args:
        audio_chunk: Samples to fade; never modified in place.
        sample_rate: Samples per second, used to size the ramps.
        fade_ms: Length of each ramp in milliseconds.

    Returns:
        A faded copy, or ``audio_chunk`` itself (same object) when the ramp
        length is not positive or the chunk is shorter than three ramps.
    """
    ramp_len = int(sample_rate * (fade_ms / 1000.0))
    if ramp_len <= 0:
        return audio_chunk
    if audio_chunk.size < 3 * ramp_len:
        return audio_chunk
    rising = np.linspace(0., 1., ramp_len, dtype=audio_chunk.dtype)
    falling = np.linspace(1., 0., ramp_len, dtype=audio_chunk.dtype)
    faded = audio_chunk.copy()  # Leave the caller's buffer untouched.
    faded[:ramp_len] *= rising
    faded[-ramp_len:] *= falling
    return faded
# =============================================================================
# Model Loading
# =============================================================================
def _warm_up_snac(model) -> None:
    """Run one small decode through the model and log the outcome; never raises."""
    try:
        logger.info("Attempting SNAC warm-up...")
        # Ten identical groups whose j-th ID lies in band j, so every group is valid.
        one_group = [
            min(ORPHEUS_MIN_ID + i * ORPHEUS_TOKENS_PER_LAYER + 100, ORPHEUS_MAX_ID - 1)
            for i in range(ORPHEUS_N_LAYERS)
        ]
        dummy_tokens = one_group * 10 # Small warmup sequence
        warmup_audio = redistribute_codes(dummy_tokens, model)
        if warmup_audio is not None and warmup_audio.size > 0:
            logger.info(f"SNAC warm-up OK (produced {warmup_audio.size} samples).")
            return
        logger.warning("SNAC warm-up call ran but produced no audio.")
    except Exception:
        logger.exception("SNAC warm-up call failed.")


def _load_snac_model():
    """Load the pretrained SNAC model onto ``tts_device``; return None on any failure."""
    logger.info("--- Loading Local SNAC Model ---")
    try:
        import warnings
        warnings.filterwarnings("ignore", category=FutureWarning, module="snac.snac")
        loaded = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        if not loaded:
            logger.error("SNAC.from_pretrained returned None. Model not loaded.")
            return None
        loaded = loaded.to(tts_device).eval()
        logger.info(f"SNAC loaded to '{tts_device}'.")
        # Optional Warmup
        _warm_up_snac(loaded)
        return loaded
    except Exception:
        logger.exception("Fatal error loading SNAC.")
        return None


# Module-level vocoder instance used by the TTS pipeline; None when unavailable.
snac_model: Optional[SNAC] = _load_snac_model() if SNAC else None
if not SNAC:
    logger.critical("SNAC library not found. TTS server cannot function.")
    # Optionally exit here if SNAC is mandatory
    # exit(1)
if not snac_model:
    logger.critical("SNAC model failed to load. TTS endpoint will return errors.")
# =============================================================================
# TTS Pipeline Function (Modified for FastAPI StreamingResponse)
# =============================================================================
# Default handed to next() so that exhaustion of the upstream line iterator is
# reported as a value: StopIteration cannot be propagated through a future.
_UPSTREAM_EXHAUSTED = object()


def _open_upstream_stream(payload: Dict[str, Any]) -> Any:
    """Blocking: POST the completion request upstream and return the streaming response."""
    return requests.post(
        UPSTREAM_TTS_API_ENDPOINT, json=payload, headers=STREAM_HEADERS, stream=True, timeout=STREAM_TIMEOUT_SECONDS
    )


def _stream_chunk_text(data: Any) -> str:
    """Return the generated text carried by one parsed upstream chunk ('' if none)."""
    try:
        choices = data.get("choices", [])
        if choices:
            first_choice = choices[0]
            return (
                first_choice.get("delta", {}).get("content")
                or first_choice.get("text")
                or ""
            )
    except (AttributeError, IndexError, TypeError):
        pass
    return ""


def _upstream_signalled_stop(data: Any) -> bool:
    """Return True when a parsed upstream chunk carries a finish reason or stop flag."""
    stop_reason = None
    if "choices" in data and data["choices"]:
        stop_reason = data["choices"][0].get("finish_reason")
    # Check common stop flags from llama.cpp server
    if stop_reason or data.get("stop") or data.get("stopped_eos") or data.get("stopped_limit"):
        logger.debug(f"Upstream TTS Stream stop condition met: reason='{stop_reason}', data flags: stop={data.get('stop')}, eos={data.get('stopped_eos')}, limit={data.get('stopped_limit')}")
        return True
    return False


def _decode_codes_to_pcm_bytes(codes: List[int], silence_samples: int, label: str) -> Optional[bytes]:
    """Blocking: decode codes with SNAC, fade and pad the audio, and return its raw bytes.

    ``label`` only distinguishes the log messages ("chunk" / "final chunk").
    Returns None when SNAC produced no audio.
    """
    snac_start_time = time.time()
    audio_chunk = redistribute_codes(codes, snac_model)
    elapsed = time.time() - snac_start_time
    if audio_chunk is None or audio_chunk.size == 0:
        logger.warning(f"--- SNAC: Failed to decode {label} ({len(codes)} codes) in {elapsed:.3f}s.")
        return None
    logger.debug(f"--- SNAC: Decoded {label} ({len(codes)} codes -> {audio_chunk.size} samples) in {elapsed:.3f}s.")
    final_chunk = apply_fade(audio_chunk, TARGET_SAMPLE_RATE, fade_ms=3)
    # Add padding if needed
    if silence_samples > 0:
        silence = np.zeros(silence_samples, dtype=AUDIO_DTYPE)
        final_chunk = np.concatenate((silence, final_chunk, silence))
    return final_chunk.tobytes()


async def generate_speech_stream_bytes(
    text: str,
    voice: str,
    tts_temperature: float,
    tts_top_p: float,
    tts_repetition_penalty: float,
    buffer_groups_param: int,
    padding_ms_param: int,
) -> AsyncGenerator[bytes, None]:
    """Generate audio chunk bytes via the upstream TTS streaming API + local SNAC.

    The upstream completion stream is read line by line, Orpheus codes are
    parsed from every text fragment, and as soon as at least
    ``buffer_groups_param`` groups are buffered every complete group is decoded
    and yielded. Complete groups left at the end of the stream are decoded as
    a final chunk.

    All blocking steps (opening the HTTP stream, waiting for the next line,
    SNAC decoding) run in the event loop's default executor so that other
    requests keep being served while one synthesis is in progress.

    Args:
        text: Text to speak; blank text yields a single empty chunk.
        voice: Voice name inserted into the prompt.
        tts_temperature: Sent upstream as "temperature".
        tts_top_p: Sent upstream as "top_p".
        tts_repetition_penalty: Sent upstream as "repeat_penalty".
        buffer_groups_param: Number of code groups to buffer before decoding.
        padding_ms_param: Silence, in milliseconds, added on each side of
            every chunk.

    Yields:
        Raw bytes of each decoded chunk, or one ``b''`` when the SNAC model is
        not loaded or the text is blank. Upstream and processing errors are
        logged and end the stream; they are not raised to the consumer.
    """
    if not snac_model:
        logger.error("generate_speech_stream_bytes called but snac_model is not loaded.")
        yield b'' # Return empty bytes to signal failure gracefully to client?
        return
    if not text.strip():
        logger.warning("generate_speech_stream_bytes called with empty text.")
        yield b''
        return
    min_codes_required = buffer_groups_param * ORPHEUS_N_LAYERS
    logger.debug(f"Stream processing: buffer={buffer_groups_param} groups ({min_codes_required} codes), padding={padding_ms_param} ms")
    silence_samples: int = 0
    if padding_ms_param > 0:
        silence_samples = int(TARGET_SAMPLE_RATE * (padding_ms_param / 1000.0))
        logger.debug(f"Calculated silence samples per side: {silence_samples}")
    payload = {
        "model": TTS_MODEL,
        "prompt": TTS_PROMPT_FORMAT.format(voice=voice, text=text),
        "temperature": tts_temperature,
        "top_p": tts_top_p,
        "repeat_penalty": tts_repetition_penalty,
        "n_predict": -1, # Predict until stop token or model limit
        "stop": TTS_PROMPT_STOP_TOKENS,
        "stream": True
    }
    loop = asyncio.get_running_loop()
    accumulated_codes: List[int] = []
    response = None
    stream_start_time = time.time()
    any_audio_yielded = False
    try:
        logger.info(">>> TTS API: Initiating stream request to upstream...")
        # requests is synchronous; calling it directly here would freeze the
        # event loop (and every other request) for the whole synthesis.
        response = await loop.run_in_executor(None, _open_upstream_stream, payload)
        response.raise_for_status()
        logger.info(f"--- TTS API: Upstream stream connected after {time.time() - stream_start_time:.3f}s. Receiving codes...")
        lines = response.iter_lines()
        while True:
            line = await loop.run_in_executor(None, next, lines, _UPSTREAM_EXHAUSTED)
            if line is _UPSTREAM_EXHAUSTED:
                break
            if not line:
                continue
            try:
                decoded_line = line.decode(response.encoding or 'utf-8')
                logger.debug(f"RAW LINE: {decoded_line[:200]}...") # Log the raw SSE line
            except UnicodeDecodeError:
                logger.warning(f"Skipping undecodable line: {line[:50]}...")
                continue
            if not decoded_line.startswith(SSE_DATA_PREFIX):
                continue
            json_str = decoded_line[len(SSE_DATA_PREFIX):].strip()
            if json_str == SSE_DONE_MARKER:
                logger.debug("Received TTS SSE_DONE_MARKER from upstream.")
                break
            if not json_str:
                continue
            try:
                data = json.loads(json_str)
                logger.debug(f"PARSED JSON STRUCTURE: {json.dumps(data, indent=2)[:500]}...")
                chunk_text = _stream_chunk_text(data)
                new_codes = parse_gguf_codes(chunk_text) if chunk_text else []
                if new_codes:
                    accumulated_codes.extend(new_codes)
                    # Process codes if enough are accumulated
                    if len(accumulated_codes) >= min_codes_required:
                        usable = (len(accumulated_codes) // ORPHEUS_N_LAYERS) * ORPHEUS_N_LAYERS
                        codes_to_decode = accumulated_codes[:usable]
                        accumulated_codes = accumulated_codes[usable:]
                        pcm = await loop.run_in_executor(
                            None, _decode_codes_to_pcm_bytes, codes_to_decode, silence_samples, "chunk"
                        )
                        if pcm is not None:
                            # Set before yielding so a consumer that stops at
                            # this chunk is not logged as "no audio".
                            any_audio_yielded = True
                            yield pcm
                # Check for stop conditions from upstream API
                if _upstream_signalled_stop(data):
                    break
            except asyncio.CancelledError:
                # Never swallow cancellation (an Exception subclass before Python 3.8).
                raise
            except json.JSONDecodeError:
                logger.warning(f"Skipping invalid JSON in upstream TTS stream: {json_str[:100]}...")
                continue
            except Exception:
                logger.exception(f"Error processing upstream TTS stream chunk: {json_str[:100]}...")
                continue # Or break depending on desired behavior
        # Process any remaining codes after the loop
        if len(accumulated_codes) >= ORPHEUS_N_LAYERS:
            logger.debug(f"Processing final {len(accumulated_codes)} codes after stream end.")
            usable = (len(accumulated_codes) // ORPHEUS_N_LAYERS) * ORPHEUS_N_LAYERS
            pcm = await loop.run_in_executor(
                None, _decode_codes_to_pcm_bytes, accumulated_codes[:usable], silence_samples, "final chunk"
            )
            if pcm is not None:
                any_audio_yielded = True
                yield pcm
        elif accumulated_codes:
            logger.debug(f"Discarding final {len(accumulated_codes)} codes (less than {ORPHEUS_N_LAYERS}).")
    except asyncio.CancelledError:
        raise
    except requests.exceptions.RequestException:
        logger.exception(f"<<< TTS API: Upstream RequestException after {time.time() - stream_start_time:.3f}s.")
    except Exception:
        logger.exception(f"<<< TTS API: Unexpected error during upstream streaming after {time.time() - stream_start_time:.3f}s.")
    finally:
        if response is not None:
            response.close()  # Release the upstream connection on every exit path.
        if not any_audio_yielded:
            logger.warning(f"<<< TTS API: Stream finished but NO audio was generated/yielded.")
        logger.info(f"<<< TTS API: Upstream stream processing finished after {time.time() - stream_start_time:.3f}s.")
# =============================================================================
# FastAPI Application
# =============================================================================
# Application object on which the routes and lifecycle hooks below are registered.
app = FastAPI(title="Streaming Orpheus TTS Server", version="1.0.0")
# Cross-origin access is limited to a single allowed origin (the local front
# end); methods and headers are unrestricted for that origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"], # Your React app's URL (adjust port if needed)
    allow_credentials=True,
    allow_methods=["*"], # Allow all methods or specify: ["GET", "POST", etc.]
    allow_headers=["*"], # Allow all headers
)
class TTSRequest(BaseModel):
    """Request body of the POST /stream-tts/ endpoint."""
    # Text to synthesise (the only required field).
    text: str
    # Voice name; the endpoint rejects values outside ALL_VOICES with HTTP 400.
    voice: str = DEFAULT_TTS_VOICE
    # Sampling parameters forwarded to the upstream completion request.
    temperature: float = DEFAULT_TTS_TEMP
    top_p: float = DEFAULT_TTS_TOP_P
    repetition_penalty: float = DEFAULT_TTS_REP_PENALTY
    # Code groups buffered before a chunk is decoded; the endpoint clamps it to 5..100.
    buffer_groups: int = DEFAULT_TTS_STREAM_MIN_GROUPS
    # Silence in milliseconds added on each side of a chunk; the endpoint clamps it to 0..500.
    padding_ms: int = DEFAULT_TTS_STREAM_SILENCE_MS
@app.on_event("startup")
async def startup_event():
    """Log server start-up and warn when the SNAC model is not loaded."""
    logger.info("Server starting up.")
    if snac_model:
        return
    logger.warning("SNAC model is not loaded. /stream-tts endpoint may fail.")
@app.on_event("shutdown")
def shutdown_event():
    """Log that the server is shutting down; no other clean-up is performed here."""
    logger.info("Server shutting down.")
@app.get("/health")
async def health_check():
    """Report whether the SNAC model is loaded and, if so, the device it runs on."""
    if not snac_model:
        return {"status": "warning", "snac_loaded": False, "message": "SNAC model failed to load."}
    return {"status": "ok", "snac_loaded": True, "device": tts_device}
@app.post("/stream-tts/")
async def stream_tts_endpoint(request: TTSRequest):
    """
    Stream synthesised speech for the requested text, voice and parameters.
    The response body is raw audio bytes (float32, mono, 24kHz); it carries no
    format metadata, so the client must know the sample rate and dtype beforehand.
    Responds with 503 when the SNAC model is not loaded and 400 for an unknown voice.
    """
    if not snac_model:
        raise HTTPException(status_code=503, detail="TTS service unavailable: SNAC model not loaded.")
    voice_is_known = request.voice in ALL_VOICES
    if not voice_is_known:
        raise HTTPException(status_code=400, detail=f"Invalid voice '{request.voice}'. Available: {ALL_VOICES}")
    logger.info(f"Received TTS request: voice='{request.voice}', text='{request.text[:50]}...', buffer={request.buffer_groups}, padding={request.padding_ms}")
    # Clamp the streaming parameters to their supported ranges (5..100 groups, 0..500 ms).
    request.buffer_groups = min(max(request.buffer_groups, 5), 100)
    request.padding_ms = min(max(request.padding_ms, 0), 500)
    stream_kwargs = {
        "text": request.text,
        "voice": request.voice,
        "tts_temperature": request.temperature,
        "tts_top_p": request.top_p,
        "tts_repetition_penalty": request.repetition_penalty,
        "buffer_groups_param": request.buffer_groups,
        "padding_ms_param": request.padding_ms,
    }
    audio_generator = generate_speech_stream_bytes(**stream_kwargs)
    # Headers can be used to send metadata if needed, e.g.,
    # headers={"X-Sample-Rate": str(TARGET_SAMPLE_RATE), "X-Audio-Dtype": "float32"}
    return StreamingResponse(audio_generator, media_type="application/octet-stream")
# =============================================================================
# Server Entry Point
# =============================================================================
if __name__ == "__main__":
    # The port comes from TTS_SERVER_PORT when set, otherwise 8001.
    port_setting = os.environ.get("TTS_SERVER_PORT", 8001)
    server_port = int(port_setting)
    logger.info(f"Starting FastAPI server on 0.0.0.0:{server_port}")
    # Use reload=True for development, False for production
    uvicorn.run(app, host="0.0.0.0", port=server_port, log_level="info")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement