CookieAnon

diffsvc_inference.py

Jun 20th, 2021 (edited)
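# DiffSVC end-to-end inference script (CookieTTS / CookiePPPTTS experiments).
# Pipeline, as implemented below: load a source wav -> loudness-normalize it ->
# extract a mel spectrogram (TacotronSTFT), a frame-level PPG (dilated_ASR aligner)
# and log-f0 (PyWORLD dio/stonemask) -> optionally shift/rescale the pitch towards the
# target speaker's f0 statistics -> run the DiffSVC diffusion decoder conditioned on
# the target speaker -> vocode the predicted mel with HiFi-GAN and write a 16-bit wav.
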
print("loading!")

# general imports
import torch
import torch.utils.data
import torch.nn.functional as F
import time
import os

# dataloader-only imports
import librosa
from scipy.io.wavfile import write
from scipy.signal import butter, sosfilt
import pyworld as pw
import numpy as np
import difflib

import pyloudnorm as pyln
import CookieTTS.utils.audio.stft as STFT
from CookieTTS.utils.dataset.utils import load_wav_to_torch

# model-only imports
from CookieTTS.experiments.DiffSVC.model     import load_model as init_model_diffsvc
from CookieTTS.experiments.dilated_ASR.model import load_model as init_model_dilated_asr
from CookieTTS._4_mtw.hifigan_ct.model       import load_generator_from_path as load_model_hifigan

def get_stft(config):
    stft = STFT.TacotronSTFT(config.filter_length, config.hop_length, config.win_length,
                             config.n_mel_channels, config.sampling_rate, config.mel_fmin,
                             config.mel_fmax, clamp_val=config.stft_clamp_val)
    return stft

def load_diffsvc_from_path(checkpoint_path, device='cuda'):
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model = init_model_diffsvc(checkpoint_dict['h'])
    model.load_state_dict(checkpoint_dict['state_dict'])
    model.to(device).eval()
    config       = checkpoint_dict['h']
    speaker_list = checkpoint_dict['speakerlist']
    spkr_f0      = checkpoint_dict['speaker_f0_meanstd']
    spkr_sylps   = checkpoint_dict['speaker_sylps_meanstd']
    return model, config, speaker_list, spkr_f0, spkr_sylps

def load_dilated_asr_from_path(checkpoint_path, device='cuda'):
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model = init_model_dilated_asr(checkpoint_dict['h'])
    model.load_state_dict(checkpoint_dict['state_dict'])
    model.to(device).eval()
    config       = checkpoint_dict['h']
    speaker_list = checkpoint_dict['speakerlist']
    return model, config, speaker_list

def load_hifigan_ct_from_path(checkpoint_path, device='cuda'):
    model, _, config = load_model_hifigan(checkpoint_path, return_hparams=True)
    #model.half()
    return model, config

def check_hparams_match(diffsvc_config, dilated_asr_config, hifigan_config):
    important_params = ('n_mel_channels', 'filter_length', 'hop_length', 'win_length', 'mel_fmin', 'mel_fmax', 'n_symbols')
    for param in important_params:
        assert getattr(diffsvc_config, param, None) == getattr(dilated_asr_config, param, None), (
            f'"{param}" param does not match between diffsvc and dilated_asr. '
            f'Got {getattr(diffsvc_config, param, None)} and {getattr(dilated_asr_config, param, None)} respectively.')

    important_vocoder_params = ('n_mel_channels', 'filter_length', 'hop_length', 'win_length', 'mel_fmin', 'mel_fmax')
    for param in important_vocoder_params:
        assert getattr(diffsvc_config, param, None) == getattr(hifigan_config, param, None), (
            f'"{param}" param does not match between diffsvc and hifigan. '
            f'Got {getattr(diffsvc_config, param, None)} and {getattr(hifigan_config, param, None)} respectively.')

def load_e2e_diffsvc(diffsvc_path, dilated_asr_path, hifigan_path, device='cuda'):
    diffsvc,    diffsvc_config, speakerlist, spkr_f0, spkr_sylps = load_diffsvc_from_path(diffsvc_path)
    dilatedasr, dilated_asr_config, _asr_speakerlist = load_dilated_asr_from_path(dilated_asr_path)# keep the diffsvc speakerlist; it matches the spkr_f0/spkr_sylps keys
    hifigan,    hifigan_config = load_hifigan_ct_from_path(hifigan_path)
    check_hparams_match(diffsvc_config, dilated_asr_config, hifigan_config)

    stft = get_stft(diffsvc_config)

    return diffsvc, dilatedasr, hifigan, stft, diffsvc_config, speakerlist, spkr_f0, spkr_sylps

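# Minimal usage sketch of the loaders + conversion helpers defined in this file
# (the checkpoint/audio/output paths are placeholders, not real ones; see test_wav() below for a full run):
#   diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps = load_e2e_diffsvc(
#       "path/to/diffsvc_checkpoint", "path/to/dilated_asr_checkpoint", "path/to/hifigan_checkpoint")
#   pred_audio = endtoend_from_path(diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps,
#                                   "path/to/source.wav", target_speaker="Twilight", correct_pitch=True)
#   write_to_file("path/to/output.wav", pred_audio, config.sampling_rate)
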
def update_loudness(audio, sampling_rate, target_lufs, max_segment_length_s=30.0):
    meter = pyln.Meter(sampling_rate)# create BS.1770 meter
    original_lufs = meter.integrated_loudness(audio[:int(max_segment_length_s*sampling_rate)].numpy())# measure loudness (in LUFS)
    original_lufs = torch.tensor(original_lufs).float().to(audio)

    delta_lufs = target_lufs - original_lufs
    gain = 10.0**(delta_lufs/20.0)
    audio = audio*gain
    if audio.abs().max() > 1.0:
        numel_over_limit = (audio.abs() > 1.0).sum()
        if numel_over_limit > audio.numel()/(sampling_rate/16):# if more than 16 samples per second exceed 1.0, peak-normalize; else just clamp them.
            audio /= audio.abs().max()
        audio.clamp_(min=-1.0, max=1.0)
    return audio

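# Worked example of the gain math above (illustrative numbers, not taken from the source):
# with target_lufs = -23.0 and a measured original_lufs of -30.0, delta_lufs = +7.0 dB,
# so gain = 10 ** (7.0 / 20.0) ≈ 2.24x in linear amplitude.
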
def get_audio_from_path(path, config):
    audio, sampling_rate = load_wav_to_torch(path, target_sr=config.sampling_rate)
    audio = update_loudness(audio, sampling_rate, config.target_lufs)
    return audio

def get_mel_from_audio(audio, stft, config):
    mel = stft.mel_spectrogram(audio.detach().cpu().unsqueeze(0))# [wav_T] -> [1, n_mel, mel_T]
    return mel

def get_loudness_from_audio(audio, sampling_rate, config, max_segment_length_s=30.0):
    meter = pyln.Meter(sampling_rate)# create BS.1770 meter
    lufs_loudness = meter.integrated_loudness(audio[:int(max_segment_length_s*sampling_rate)].numpy())# measure loudness (in LUFS)
    lufs_loudness = torch.tensor(lufs_loudness).float()
    return lufs_loudness

def get_pitch(audio, sampling_rate, hop_length, f0_floors=[56.,], f0=None, refine_pitch=True, f0_ceil=1500., voiced_sensitivity=0.13):
    """
    audio: torch.FloatTensor [wav_T]
    sampling_rate: int
    hop_length: int
    f0_floors: list[float]
        - f0_floors is a list of minimum pitch values, tried in order.
          Wherever an earlier f0_floor found no pitch (f0 == 0) but a later f0_floor did,
          the later value is used, so each extra floor fills in frames the previous
          floors left unvoiced.
    """
    if type(f0_floors) in [int, float]:
        f0_floors = [f0_floors,]
    # Extract Pitch/f0 from raw waveform using PyWORLD
    audio = torch.cat((audio, audio[-1:]), dim=0)
    audio = audio.numpy().astype(np.float64)

    for f0_floor in f0_floors:
        f0raw, timeaxis = pw.dio(# get raw pitch
            audio, sampling_rate,
            frame_period=(hop_length/sampling_rate)*1000.,# in ms; for hop size 256 the frame period is ~11.6 ms
            f0_floor=f0_floor,# f0_floor : float
                              #     Lower F0 limit in Hz.
                              #     Default: 71.0
            f0_ceil=f0_ceil,# f0_ceil : float
                            #     Upper F0 limit in Hz.
                            #     Default: 800.0
            allowed_range=voiced_sensitivity,# allowed_range : float
                            #     Threshold for voiced/unvoiced decision. Can be any value >= 0, but 0.02 to 0.2
                            #     is a reasonable range. Lower values cause more frames to be considered
                            #     unvoiced (in the extreme case of `allowed_range=0`, almost all frames will be unvoiced).
            )
        if refine_pitch:# improves loss values in FastSpeech2 style decoder.
            f0raw = pw.stonemask(audio, f0raw, timeaxis, sampling_rate)# pitch refinement
        f0raw = torch.from_numpy(f0raw).float().clamp(min=0.0, max=f0_ceil)# [mel_T]
        f0 = f0raw if f0 is None else torch.where(f0==0.0, f0raw, f0)# where the current f0 is unvoiced but the new f0raw is voiced, take the new pitch.
    voiced_mask = (f0>3)# voiced/unvoiced flag
    return f0, voiced_mask# [mel_T], [mel_T]

def get_logf0_from_audio(audio, config):
    f0, vo = get_pitch(audio, config.sampling_rate, config.hop_length,
                       getattr(config, 'f0_floors', [55., 78., 110., 156.]), None, refine_pitch=True,
                       f0_ceil=getattr(config, 'f0_ceil', 1500.),
                       voiced_sensitivity=getattr(config, 'voiced_sensitivity', 0.10))
    logf0 = f0.log().where(vo, f0[0]*0.0)# log(f0) on voiced frames, 0.0 on unvoiced frames
    return logf0

def get_ppg_from_mel(mel, model, config, mel_lengths=None):
    model_device, model_dtype = next(model.parameters()).device, next(model.parameters()).dtype
    if mel_lengths is None:
        mel_lengths = torch.tensor([mel.shape[2],]).long()# [B] holding mel_T, taken from mel [B, n_mel, mel_T]
    ppg = model.generator.align(mel.to(model_device, model_dtype), mel_lengths.to(model_device))
    return ppg

def write_to_file(path, audio, sampling_rate):
    audio = (audio.float() * 2**15).squeeze().cpu().numpy().astype('int16')# float [-1, 1] -> int16 PCM
    write(path, sampling_rate, audio)

def endtoend_from_path(diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps,
                       audiopath, target_speaker, correct_pitch, t_step_size=1, t_max_step=None):
    audio = get_audio_from_path(audiopath, config)
    pred_audio = endtoend(diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps,
                          audio, target_speaker, correct_pitch, t_step_size=t_step_size, t_max_step=t_max_step)
    return pred_audio

def endtoend_from_cache(diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps,
                        audiopath, target_speaker, correct_pitch, t_step_size=1, t_max_step=None):
    # NOTE: currently identical to endtoend_from_path; no cached features are loaded here.
    audio = get_audio_from_path(audiopath, config)
    pred_audio = endtoend(diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps,
                          audio, target_speaker, correct_pitch, t_step_size=t_step_size, t_max_step=t_max_step)
    return pred_audio

@torch.no_grad()
def endtoend(diffsvc, dilatedasr, hifigan, stft, config, speakerlist, spkr_f0, spkr_sylps,
             audio, target_speaker, correct_pitch, t_step_size=1, t_max_step=None, gt_mel=None, frame_ppg=None, gt_frame_logf0=None):# only supports a single audio file at a time
    if gt_mel is None or frame_ppg is None:# recompute the mel if it is missing or if the PPG must be derived from it
        gt_mel = get_mel_from_audio(audio, stft, config)# [1, n_mel, mel_T]
    if frame_ppg is None:
        frame_ppg = get_ppg_from_mel(gt_mel, dilatedasr, config)
    if gt_frame_logf0 is None:
        gt_frame_logf0 = get_logf0_from_audio(audio, config).unsqueeze(0)

    mel_lengths = torch.tensor([gt_mel.shape[2],]).long()
    gt_perc_loudness = torch.tensor([config.target_lufs,])# the source audio is loudness-normalized to target_lufs in get_audio_from_path

    possible_names = [x[1].lower() for x in speakerlist]
    speaker_lookup = {x[1].lower(): x[2] for x in speakerlist}
    speaker = difflib.get_close_matches(target_speaker.lower(), possible_names, n=2, cutoff=0.01)[0]# get the closest known name to target_speaker
    print(f"Selected speaker: {speaker}")
    speaker_id_ext = speaker_lookup[speaker]

    (speaker_id,
     speaker_f0_meanstd,
     speaker_sylps_meanstd) = speaker_id_ext, spkr_f0[speaker_id_ext], spkr_sylps[speaker_id_ext]
    speaker_id            = torch.tensor([speaker_id,]).long()
    speaker_f0_meanstd    = torch.tensor([speaker_f0_meanstd,])
    speaker_sylps_meanstd = torch.tensor([speaker_sylps_meanstd,])
    #print(f"F0 Mean {speaker_f0_meanstd[0, 0].item()} | STD {speaker_f0_meanstd[0, 1].item()}")
    #print(f"SR Mean {speaker_sylps_meanstd[0, 0].item()} | STD {speaker_sylps_meanstd[0, 1].item()}")

    if correct_pitch:# shift the source pitch mean to the target speaker's mean
        correction_shift = speaker_f0_meanstd[:, 0].log() - gt_frame_logf0[gt_frame_logf0!=0.0].float().mean()
        gt_frame_logf0[gt_frame_logf0!=0.0] += correction_shift

    if True:# rescale the source pitch spread to the target speaker's std
        correction_scale = speaker_f0_meanstd[:, 1]/gt_frame_logf0[gt_frame_logf0!=0.0].float().exp().std()
        mean = gt_frame_logf0[gt_frame_logf0!=0.0].mean()
        gt_frame_logf0[gt_frame_logf0!=0.0] = gt_frame_logf0[gt_frame_logf0!=0.0].sub(mean).exp().mul(correction_scale).log().add(mean)

    # move all features to the correct device + dtype
    diffsvc_device, diff_dtype = next(diffsvc.parameters()).device, next(diffsvc.parameters()).dtype
    gt_mel                = gt_mel               .to(diffsvc_device, diff_dtype)
    gt_perc_loudness      = gt_perc_loudness     .to(diffsvc_device, diff_dtype)
    gt_frame_logf0        = gt_frame_logf0       .to(diffsvc_device, diff_dtype)
    frame_ppg             = frame_ppg            .to(diffsvc_device, diff_dtype)
    mel_lengths           = mel_lengths          .to(diffsvc_device, torch.long)
    speaker_id            = speaker_id           .to(diffsvc_device, torch.long)
    speaker_f0_meanstd    = speaker_f0_meanstd   .to(diffsvc_device, diff_dtype)
    speaker_sylps_meanstd = speaker_sylps_meanstd.to(diffsvc_device, diff_dtype)

    pred_mel = diffsvc.generator.voice_conversion_main(
        gt_mel, mel_lengths,   # FloatTensor[B, n_mel, mel_T], LongTensor[B] # taken from reference/source
        gt_perc_loudness,      # FloatTensor[B]                              # taken from reference/source
        gt_frame_logf0,        # FloatTensor[B, mel_T]                       # taken from reference/source
        frame_ppg,             # FloatTensor[B, ppg_dim, mel_T]              # taken from reference/source
        speaker_id,            #  LongTensor[B]                              # taken from target speaker
        speaker_f0_meanstd,    # FloatTensor[B, 2]                           # taken from target speaker
        speaker_sylps_meanstd, # FloatTensor[B, 2]                           # taken from target speaker
        t_step_size=t_step_size,# int
        t_max_step=t_max_step).transpose(1, 2)# -> [B, n_mel, mel_T]

    hifigan_device, hifigan_dtype = next(hifigan.parameters()).device, next(hifigan.parameters()).dtype
    pred_audio = hifigan(pred_mel.to(hifigan_device, hifigan_dtype))

    return pred_audio



# testing
def test_wav():# test the model with data computed from the functions above
    audiopath = ("/media/cookie/Samsung 860 QVO/TTS/"
                 "voiceline_2.wav")
    target_speakers = ['Twilight','Pinkie','Discord','Nancy','Yosuke','Adachi']

    diffsvc, dilatedasr, hifigan, stft, diffsvc_config, speakerlist, spkr_f0, spkr_sylps = load_e2e_diffsvc(
        diffsvc_path     = "/media/cookie/Samsung PM961/TwiBot/CookiePPPTTS/CookieTTS/experiments/DiffSVC/outdir_015_7x3/latest_val_model",
        dilated_asr_path = "/media/cookie/Samsung PM961/TwiBot/CookiePPPTTS/CookieTTS/experiments/dilated_ASR/outdir_002/checkpoint_87000",
        hifigan_path     = "/media/cookie/Samsung PM961/TwiBot/CookiePPPTTS/CookieTTS/_4_mtw/hifigan_ct/outdir_u4_warm_oggless/latest_val_model",
    )

    # configure the diffusion decoder's noise schedule
    lin_start   = 1e-4
    lin_end     = 0.24
    lin_n_steps = 1000
    diffsvc.generator.diffusion.set_noise_schedule(lin_start, lin_end, lin_n_steps, device='cuda')

    for correct_pitch in [True,]:
        t_step_size = 1
        for target_speaker in target_speakers:
            for max_t in range(0, lin_n_steps+1, lin_n_steps//2):# sweep t_max_step over 0, 500 and 1000
                pred_audio = endtoend_from_path(diffsvc, dilatedasr, hifigan, stft, diffsvc_config, speakerlist, spkr_f0, spkr_sylps,
                                                audiopath, target_speaker, correct_pitch, t_step_size=t_step_size, t_max_step=max_t)

                outpath = f"/media/cookie/Samsung 860 QVO/TTS/output_spkr{target_speaker}_max{max_t:04}_step{t_step_size}_{'mod' if correct_pitch else 'orig'}pitch.wav"
                write_to_file(outpath, pred_audio, diffsvc_config.sampling_rate)
                print(f"Wrote audio to '{outpath}'")

if __name__ == '__main__':
    test_wav()
