from flask import Flask, render_template, request, jsonify
import requests
import os
import soundfile as sf
import numpy as np
from kokoro import KPipeline  # Your existing Kokoro TTS
import uuid
import base64
import io  # For handling in-memory audio

# Attempt to import NeMo ASR
try:
    import nemo.collections.asr as nemo_asr
    print("NVIDIA NeMo ASR Toolkit imported successfully.")
except ImportError:
    print("NVIDIA NeMo ASR Toolkit not found. Please install it: pip install nemo_toolkit['asr']")
    nemo_asr = None
except Exception as e:
    print(f"Error importing NeMo ASR: {e}")
    nemo_asr = None

app = Flask(__name__)

# --- Configuration ---
OLLAMA_API_URL = "http://localhost:11434/api/generate"
OLLAMA_TAGS_URL = "http://localhost:11434/api/tags"
MODEL_DIR = "models"  # General models directory
ASR_MODEL_DIR = os.path.join(MODEL_DIR, "asr_models")  # Specific to ASR
KOKORO_MODEL_PATH = os.path.join(MODEL_DIR, "kokoro-v1.0.onnx")
KOKORO_VOICES_PATH = os.path.join(MODEL_DIR, "voices-v1.0.bin")
KOKORO_LANG_CODE = "a"
VOICE = "af_bella"

# Parakeet ASR model configuration
ASR_MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"

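# Note: in Kokoro's KPipeline, lang_code "a" selects American English, and
# "af_bella" is one of the American-English female voices shipped in the
# voices file. The Parakeet model operates on 16 kHz mono audio (worth
# verifying against the model card); NeMo's transcribe() should resample
# WAV input, but sending 16 kHz mono from the client avoids surprises.
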
# --- Initialize Services ---
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(ASR_MODEL_DIR, exist_ok=True)  # Ensure ASR model directory exists
os.makedirs("static", exist_ok=True)  # For temporary audio files if needed

# Initialize Kokoro TTS
tts_pipeline = None
try:
    # A "repo_id" warning may appear here; it comes from the Hugging Face
    # libraries Kokoro uses internally (the model is likely hexgrad/Kokoro-82M),
    # and KPipeline itself may not accept a repo_id argument to suppress it.
    # Since Kokoro initializes correctly regardless, the warning is left as-is.
    if os.path.exists(KOKORO_MODEL_PATH) and os.path.exists(KOKORO_VOICES_PATH):
        tts_pipeline = KPipeline(lang_code=KOKORO_LANG_CODE)
        print("Kokoro TTS initialized successfully.")
    else:
        print("Kokoro TTS model/voice files not found. Skipping initialization.")
except Exception as e:
    print(f"Error initializing Kokoro TTS: {str(e)}")

# Initialize NeMo ASR Model
asr_model_instance = None
if nemo_asr:
    try:
        print(f"Loading ASR model: {ASR_MODEL_NAME}...")
        asr_model_instance = nemo_asr.models.ASRModel.from_pretrained(
            model_name=ASR_MODEL_NAME,
            map_location='cpu'  # Use 'cuda' if you have a GPU and CUDA installed
        )
        asr_model_instance.eval()  # Set to evaluation mode
        print(f"ASR model '{ASR_MODEL_NAME}' loaded successfully.")
    except Exception as e:
        print(f"Error loading ASR model '{ASR_MODEL_NAME}': {str(e)}")
        print("ASR will not be available.")
else:
    print("NeMo ASR toolkit not available. ASR functionality will be disabled.")

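# A minimal sketch of selecting the device automatically instead of hardcoding
# 'cpu' above (assumes PyTorch, which NeMo depends on, is importable):
#
#   import torch
#   asr_device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   asr_model_instance = nemo_asr.models.ASRModel.from_pretrained(
#       model_name=ASR_MODEL_NAME, map_location=asr_device)
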

def download_kokoro_model():
    global tts_pipeline  # Declared global so the pipeline can be initialized after download
    if not os.path.exists(KOKORO_MODEL_PATH) or not os.path.exists(KOKORO_VOICES_PATH):
        print("Attempting to download Kokoro TTS model files...")
        try:
            model_url = "https://github.com/nazdridoy/kokoro-tts/releases/download/v1.0.0/kokoro-v1.0.onnx"
            voices_url = "https://github.com/nazdridoy/kokoro-tts/releases/download/v1.0.0/voices-v1.0.bin"
            for url, path in [(model_url, KOKORO_MODEL_PATH), (voices_url, KOKORO_VOICES_PATH)]:
                if not os.path.exists(path):
                    print(f"Downloading {url} to {path}")
                    response = requests.get(url, stream=True)
                    response.raise_for_status()
                    with open(path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
            if not tts_pipeline and os.path.exists(KOKORO_MODEL_PATH) and os.path.exists(KOKORO_VOICES_PATH):
                tts_pipeline = KPipeline(lang_code=KOKORO_LANG_CODE)
                print("Kokoro TTS initialized successfully after download.")
        except Exception as e:
            print(f"Error downloading Kokoro model: {str(e)}")

download_kokoro_model()

def get_ollama_models():
    try:
        response = requests.get(OLLAMA_TAGS_URL)
        response.raise_for_status()
        models_data = response.json().get("models", [])
        return [model["name"] for model in models_data]
    except requests.RequestException as e:
        print(f"Error fetching Ollama models: {str(e)}")
        return ["llama2:latest"]

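# Both routes below build an identical Ollama request; a sketch of a shared
# helper (hypothetical name query_ollama) that could replace the duplicated
# blocks:
#
#   def query_ollama(model, prompt, system_prompt):
#       payload = {"model": model, "prompt": prompt,
#                  "system": system_prompt, "stream": False}
#       response = requests.post(OLLAMA_API_URL, json=payload)
#       response.raise_for_status()
#       return response.json().get("response", "")
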
@app.route('/')
def index():
    models = get_ollama_models()
    return render_template('index2.html', models=models)

@app.route('/process_voice_input', methods=['POST'])
def process_voice_input():
    if 'audio_data' not in request.files:
        return jsonify({"error": "No audio data found in request"}), 400
    if not asr_model_instance:
        return jsonify({"error": "ASR model not available on server"}), 500

    audio_file = request.files['audio_data']

    temp_audio_filename = f"temp_audio_{uuid.uuid4()}.wav"
    temp_audio_path = os.path.join("static", temp_audio_filename)

    try:
        audio_file.save(temp_audio_path)
        print(f"Temporary audio file saved to: {temp_audio_path}")

        # Transcribe audio using the NeMo ASR model
        transcription_results = asr_model_instance.transcribe([temp_audio_path])

        user_input_text = ""  # Default to empty if nothing usable comes back

        if transcription_results:
            # NeMo's transcribe() can return, depending on configuration:
            #   1. a list of strings (the simple, non-N-best case),
            #   2. a list of Hypothesis objects (one per input file), or
            #   3. a list of lists of Hypothesis objects (N-best, even with N=1).
            # The observed log output, [Hypothesis(text="Hey, what's up?", ...)],
            # indicates a list whose first element is a Hypothesis object.

            first_file_result = transcription_results[0]  # Result for the first (and only) file

            if isinstance(first_file_result, str):
                user_input_text = first_file_result
            elif isinstance(first_file_result, list) and len(first_file_result) > 0:
                # N-best case: first_file_result is a list of Hypothesis objects
                if hasattr(first_file_result[0], 'text'):
                    user_input_text = first_file_result[0].text  # Take the top hypothesis
                else:
                    print(f"Warning: Top hypothesis in N-best list lacks a 'text' attribute: {first_file_result[0]}")
            elif hasattr(first_file_result, 'text'):
                # Single Hypothesis object for the input file
                user_input_text = first_file_result.text
            else:
                print(f"Warning: Unexpected transcription result format: {first_file_result}")
        else:
            print("Warning: ASR returned no transcription results at all.")

        print(f"Transcribed text: '{user_input_text}'")

    except Exception as e:
        print(f"Error during ASR transcription: {str(e)}")
        return jsonify({"error": f"ASR transcription error: {str(e)}"}), 500
    finally:
        if os.path.exists(temp_audio_path):
            try:
                os.remove(temp_audio_path)
                print(f"Temporary audio file {temp_audio_path} removed.")
            except Exception as e_remove:
                print(f"Error removing temporary audio file {temp_audio_path}: {e_remove}")

    if not user_input_text.strip():
        return jsonify({
            "transcribed_text": user_input_text,
            "text": "Could not understand audio or audio was silent.",
            "audio": None
        })

    selected_model = request.form.get('model', get_ollama_models()[0])
    system_prompt = request.form.get('system_prompt', "You are a helpful, friendly AI assistant.")

    available_models = get_ollama_models()
    if selected_model not in available_models:
        selected_model = available_models[0] if available_models else "llama2:latest"

    try:
        ollama_payload = {
            "model": selected_model,
            "prompt": user_input_text,
            "system": system_prompt,
            "stream": False
        }
        print(f"Sending to Ollama: {ollama_payload}")
        ollama_response = requests.post(OLLAMA_API_URL, json=ollama_payload)
        ollama_response.raise_for_status()
        llm_response_text = ollama_response.json().get("response", "")
        print(f"Ollama response: '{llm_response_text}'")
    except requests.RequestException as e:
        print(f"Ollama API error: {str(e)}")
        return jsonify({"error": f"Ollama API error: {str(e)}", "transcribed_text": user_input_text}), 500

    tts_audio_base64 = None
    if tts_pipeline and llm_response_text:
        try:
            generator = tts_pipeline(llm_response_text, voice=VOICE)
            audio_chunks = [audio for _, _, audio in generator]
            if audio_chunks:
                final_audio_np = np.concatenate(audio_chunks)
                wav_buffer = io.BytesIO()
                sf.write(wav_buffer, final_audio_np, 24000, format='WAV', subtype='PCM_16')
                wav_buffer.seek(0)
                tts_audio_base64 = base64.b64encode(wav_buffer.read()).decode('utf-8')
                print("Kokoro TTS audio generated.")
            else:
                print("Kokoro TTS produced no audio chunks.")
        except Exception as e:
            print(f"Kokoro TTS generation error: {str(e)}")
            llm_response_text += " (TTS Error)"
    elif not tts_pipeline:
        print("Kokoro TTS pipeline not available.")
    elif not llm_response_text:
        print("LLM response was empty, skipping TTS.")

    return jsonify({
        "transcribed_text": user_input_text,
        "text": llm_response_text,
        "audio": tts_audio_base64
    })

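# Example client call for the route above (a sketch; assumes a WAV recording
# saved as input.wav):
#
#   curl -X POST http://localhost:5000/process_voice_input \
#        -F "audio_data=@input.wav" \
#        -F "model=llama2:latest" \
#        -F "system_prompt=You are a helpful, friendly AI assistant."
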
@app.route('/process_typed_text', methods=['POST'])
def process_typed_text():
    data = request.json
    user_input_text = data.get('text')
    selected_model = data.get('model')
    system_prompt = data.get('system_prompt', "You are a helpful, friendly AI assistant.")

    if not user_input_text:
        return jsonify({"error": "No input text provided"}), 400

    available_models = get_ollama_models()
    if selected_model not in available_models:
        selected_model = available_models[0] if available_models else "llama2:latest"

    try:
        ollama_payload = {
            "model": selected_model,
            "prompt": user_input_text,
            "system": system_prompt,
            "stream": False
        }
        print(f"Sending typed text to Ollama: {ollama_payload}")
        ollama_response = requests.post(OLLAMA_API_URL, json=ollama_payload)
        ollama_response.raise_for_status()
        llm_response_text = ollama_response.json().get("response", "")
        print(f"Ollama response to typed text: '{llm_response_text}'")
    except requests.RequestException as e:
        print(f"Ollama API error (typed text): {str(e)}")
        return jsonify({"error": f"Ollama API error: {str(e)}"}), 500

    tts_audio_base64 = None
    if tts_pipeline and llm_response_text:
        try:
            generator = tts_pipeline(llm_response_text, voice=VOICE)
            audio_chunks = [audio for _, _, audio in generator]
            if audio_chunks:
                final_audio_np = np.concatenate(audio_chunks)
                wav_buffer = io.BytesIO()
                sf.write(wav_buffer, final_audio_np, 24000, format='WAV', subtype='PCM_16')
                wav_buffer.seek(0)
                tts_audio_base64 = base64.b64encode(wav_buffer.read()).decode('utf-8')
                print("Kokoro TTS for typed text generated.")
        except Exception as e:
            print(f"Kokoro TTS generation error (typed text): {str(e)}")
            llm_response_text += " (TTS Error)"

    return jsonify({
        "transcribed_text": None,
        "text": llm_response_text,
        "audio": tts_audio_base64
    })

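# Example client call for the route above (a sketch):
#
#   curl -X POST http://localhost:5000/process_typed_text \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello there", "model": "llama2:latest"}'
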
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
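
# To run (a sketch of the expected setup): start Ollama so it serves
# http://localhost:11434, make sure templates/index2.html exists, then launch
# this script with Python. The app listens on http://0.0.0.0:5000.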