  1. print("WARNING: You are running this unofficial E2/F5 TTS demo locally, it may not be as up-to-date as the hosted version (https://huggingface.co/spaces/mrfakename/E2-F5-TTS)")
  2.  
  3. import os
  4. import re
  5. import torch
  6. import torchaudio
  7. import gradio as gr
  8. import numpy as np
  9. import tempfile
  10. from einops import rearrange
  11. from ema_pytorch import EMA
  12. from vocos import Vocos
  13. from pydub import AudioSegment
  14. from model import CFM, UNetT, DiT, MMDiT
  15. from cached_path import cached_path
  16. from model.utils import (
  17. get_tokenizer,
  18. convert_char_to_pinyin,
  19. save_spectrogram,
  20. )
  21. from transformers import pipeline
  22. import librosa
  23. import re
  24. import gc
  25. import matplotlib.pyplot as plt
  26. from safetensors.torch import load_file
  27.  
  28. device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  29.  
  30. gc.collect()
  31. torch.cuda.empty_cache()
  32.  
  33. print(f"Using {device} device")
  34.  
  35.  
# --------------------- Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = 'euler'
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None

def load_model(exp_name, model_cls, model_cfg, ckpt_step):
    # NOTE: "/path/F5TTS/..." is a local placeholder; point it at your checkpoint directory.
    checkpoint = load_file(str(cached_path(f"/path/F5TTS/{exp_name}/model_{ckpt_step}.safetensors")))
    # print(checkpoint.keys())
    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    model = CFM(
        transformer=model_cls(
            **model_cfg,
            text_num_embeds=vocab_size,
            mel_dim=n_mel_channels
        ),
        mel_spec_kwargs=dict(
            target_sample_rate=target_sample_rate,
            n_mel_channels=n_mel_channels,
            hop_length=hop_length,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # The safetensors checkpoint stores the EMA weights under an 'ema_model.' prefix;
    # strip the prefix and load them directly into the base model.
    ema_state_dict = {}
    for key, value in checkpoint.items():
        if key.startswith('ema_model.'):
            ema_state_dict[key[len('ema_model.'):]] = value
    model.load_state_dict(ema_state_dict)

    ema_model = EMA(model, include_online_model=False).to(device)
    #ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
    #ema_model.copy_params_from_ema_to_model()

    return ema_model, model

# load models
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

def chunk_text(text, max_chars=200):
    chunks = []
    current_chunk = ""
    sentences = re.split(r'(?<=[.!?])\s+', text)

    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= max_chars:
            current_chunk += sentence + " "
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks

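# Example of chunk_text's behaviour: sentences are split on ".", "!" or "?" followed by
# whitespace and greedily packed into chunks of at most max_chars characters, e.g.
#   chunk_text("One. Two. Three.", max_chars=9) -> ["One. Two.", "Three."]
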
# Shadows the save_spectrogram imported from model.utils; this version plots a waveform directly.
def save_spectrogram(y, sr, path):
    plt.figure(figsize=(10, 4))
    D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Spectrogram')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()

def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
    print(gen_text)
    chunks = chunk_text(gen_text)

    if not chunks:
        raise gr.Error("Please enter some text to generate.")

    # Convert reference audio to WAV and clip it to at most 15 s
    gr.Info("Converting reference audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)
        audio_duration = len(aseg)
        if audio_duration > 15000:
            gr.Warning("Audio is over 15s, clipping to only first 15s.")
            aseg = aseg[:15000]
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Select model
    if exp_name == "F5-TTS":
        ema_model = F5TTS_ema_model
        base_model = F5TTS_base_model
    elif exp_name == "E2-TTS":
        ema_model = E2TTS_ema_model
        base_model = E2TTS_base_model

    # Transcribe reference audio if no reference text was provided
    if not ref_text.strip():
        gr.Info("No reference text provided, transcribing reference audio...")
        # Initialize Whisper model
        pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",  # switch to large-v3 for better quality, but VRAM usage then rises to about 10 GB
            torch_dtype=torch.float16,
            device=device,
        )
        ref_text = pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )['text'].strip()
        print("\nTranscribed text: ", ref_text)  # Debug transcription quality
        gr.Info("\nFinished transcription")
        # Release Whisper model
        del pipe
        torch.cuda.empty_cache()
        gc.collect()
    else:
        gr.Info("Using custom reference text...")

    # Load and preprocess reference audio
    audio, sr = torchaudio.load(ref_audio)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)  # convert to mono
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    # Process each chunk
    results = []
    spectrograms = []

    for i, chunk in enumerate(chunks):
        gr.Info(f"Processing chunk {i+1}/{len(chunks)}: {chunk[:30]}...")

        # Prepare the text: the reference transcript is prepended to the text to generate
        text_list = [ref_text + chunk]
        final_text_list = convert_char_to_pinyin(text_list)

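        # The duration estimate below gives the generated segment a mel-frame budget proportional
        # to the reference audio length, scaled by the generated/reference text-length ratio
        # (Chinese pause punctuation is counted twice so pauses get extra frames) and divided by `speed`.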
        # Calculate duration
        ref_audio_len = audio.shape[-1] // hop_length
        zh_pause_punc = r"[。,、;:?!]"
        ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
        gen_text_len = len(chunk) + len(re.findall(zh_pause_punc, chunk))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # Inference
        gr.Info(f"Generating audio using {exp_name}")
        with torch.inference_mode():
            generated, _ = base_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        # Keep only the newly generated frames (drop the reference prompt portion)
        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')

        # Clear unnecessary tensors
        del generated
        torch.cuda.empty_cache()

        gr.Info("Running vocoder")
        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        # Convert to numpy and clear GPU tensors
        generated_wave = generated_wave.squeeze().cpu().numpy()
        del generated_mel_spec
        torch.cuda.empty_cache()

        results.append(generated_wave)

        # Generate spectrogram
        #with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        #    spectrogram_path = tmp_spectrogram.name
        #    save_spectrogram(generated_wave, target_sample_rate, spectrogram_path)
        #spectrograms.append(spectrogram_path)

        # Clear cache after processing each chunk
        gc.collect()
        torch.cuda.empty_cache()

    # Combine all audio chunks
    combined_audio = np.concatenate(results)

    if remove_silence:
        gr.Info("Removing audio silences... This may take a moment")
        non_silent_intervals = librosa.effects.split(combined_audio, top_db=30)
        non_silent_wave = np.array([])
        for interval in non_silent_intervals:
            start, end = interval
            non_silent_wave = np.concatenate([non_silent_wave, combined_audio[start:end]])
        combined_audio = non_silent_wave

    # Generate final spectrogram
    #with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
    #    final_spectrogram_path = tmp_spectrogram.name
    #    save_spectrogram(combined_audio, target_sample_rate, final_spectrogram_path)

    # Final cleanup
    gc.collect()
    torch.cuda.empty_cache()

    # Return the combined audio and the (possibly auto-transcribed) reference text
    return (target_sample_rate, combined_audio), ref_text

with gr.Blocks() as app:
    gr.Markdown("""
# E2/F5 TTS

This is an unofficial E2/F5 TTS demo. It supports the following TTS models:

* [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)

This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which in turn builds on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).

The checkpoints support English and Chinese.

If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).

**NOTE: If no reference text is provided, it will be transcribed automatically with Whisper. For best results, keep your reference clips short (<15s) and make sure the audio has finished uploading before you generate.**
""")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate (texts longer than 200 characters are split into chunks)", lines=4)
    model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
    generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Accordion("Advanced Settings", open=False):
        ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text, it overrides the automatic transcription.", lines=2)
        remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can remove them manually if needed. Note that this is an experimental feature and may produce strange results. It will also increase generation time.", value=True)

    audio_output = gr.Audio(label="Synthesized Audio")
    #spectrogram_output = gr.Image(label="Spectrogram")

    # Clear any stale reference text whenever a new reference audio file is uploaded
    ref_audio_input.upload(
        fn=lambda ref_audio: (ref_audio, "" if ref_audio else None),  # pass the audio through, clear ref_text
        inputs=ref_audio_input,
        outputs=[ref_audio_input, ref_text_input],
    )

    generate_btn.click(
        infer,
        inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence],
        outputs=[audio_output, ref_text_input],
    )
    gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")


app.queue().launch(server_name="0.0.0.0", server_port=7860)
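
# Usage (assumed layout): run this file from the root of the F5-TTS repository so that the
# `model` package and `model.utils` import correctly, after pointing the checkpoint path in
# load_model() at your local .safetensors files. The Gradio UI is then served on port 7860
# (bound to 0.0.0.0, so it is also reachable from other machines on the network).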