nikich340

inference_batch_from_csv.py

Nov 5th, 2021 (edited)
234
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.44 KB | None | 0 0
  1. ###############################################################################
  2. #
  3. #  Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  4. #  Licensed under the Apache License, Version 2.0 (the "License");
  5. #  you may not use this file except in compliance with the License.
  6. #  You may obtain a copy of the License at
  7. #
  8. #      http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. #  Unless required by applicable law or agreed to in writing, software
  11. #  distributed under the License is distributed on an "AS IS" BASIS,
  12. #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. #  See the License for the specific language governing permissions and
  14. #  limitations under the License.
  15. #
  16. ###############################################################################
  17. import matplotlib
  18. matplotlib.use("Agg")
  19. import matplotlib.pylab as plt
  20.  
  21. import os
  22. import argparse
  23. import json
  24. import sys
  25. import numpy as np
  26. import torch
  27. import csv
  28.  
  29.  
  30. from flowtron import Flowtron
  31. from torch.utils.data import DataLoader
  32. from data import Data
  33. from train import update_params
  34.  
  35. sys.path.insert(0, "tacotron2")
  36. sys.path.insert(0, "tacotron2/waveglow")
  37. from glow import WaveGlow
  38. from scipy.io.wavfile import write
  39.  
  40.  
  41. def load_text_batch(csv_path, text_lines, text_ids):
  42.     with open(csv_path) as csv_file:
  43.         csv_reader = csv.reader(csv_file, delimiter='|')
  44.         line_count = 0
  45.         for row in csv_reader:
  46.             line_count += 1
  47.             if line_count == 1:
  48.                 continue
  49.             else:
  50.                 text_ids.append(row[0])
  51.                 text_lines.append(row[1])
  52.     print("Batch: loaded {} lines from file.".format(line_count))
  53.    
  54. def infer(flowtron_path, waveglow_path, output_dir, csv_path, speaker_id, n_frames,
  55.           sigma, gate_threshold, seed):
  56.     torch.manual_seed(seed)
  57.     torch.cuda.manual_seed(seed)
  58.  
  59.     # load waveglow
  60.     waveglow = torch.load(waveglow_path)['model'].cuda().eval()
  61.     waveglow.cuda().half()
  62.     for k in waveglow.convinv:
  63.         k.float()
  64.     waveglow.eval()
  65.  
  66.     # load flowtron
  67.     try:
  68.         model = Flowtron(**model_config).cuda()
  69.         state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
  70.         model.load_state_dict(state_dict)
  71.     except KeyError:
  72.         # model saved by train.py module
  73.         # do not need to load state dict
  74.         # and can be used directly
  75.         model = torch.load(flowtron_path)['model']
  76.     model.eval()
  77.     print("Loaded checkpoint '{}')" .format(flowtron_path))
  78.  
  79.     ignore_keys = ['training_files', 'validation_files']
  80.     trainset = Data(
  81.         data_config['training_files'],
  82.         **dict((k, v) for k, v in data_config.items() if k not in ignore_keys))
  83.    
  84.     speaker_vecs = trainset.get_speaker_id(speaker_id).cuda()
  85.     speaker_vecs = speaker_vecs[None]
  86.    
  87.     with torch.no_grad():
  88.         residual = torch.cuda.FloatTensor(1, 80, n_frames).normal_() * sigma
  89.    
  90.     text_lines = []
  91.     text_ids = []
  92.     load_text_batch(csv_path, text_lines, text_ids)
  93.    
  94.     for idx in range(len(text_lines)):
  95.         text = text_lines[idx]
  96.         file_suffix = text_ids[idx] # + (text[:20] + "..") if len(text) > 20 else text
  97.         text = trainset.get_text(text).cuda()
  98.        
  99.         text = text[None]
  100.  
  101.         with torch.no_grad():
  102.             #residual = torch.cuda.FloatTensor(1, 80, n_frames).normal_() * sigma
  103.             mels, attentions = model.infer(
  104.                 residual, speaker_vecs, text, gate_threshold=gate_threshold)
  105.  
  106.         for k in range(len(attentions)):
  107.             attention = torch.cat(attentions[k]).cpu().numpy()
  108.             fig, axes = plt.subplots(1, 2, figsize=(16, 4))
  109.             #axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
  110.             #axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
  111.             axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
  112.             axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
  113.             #fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_frames{}_attnlayer{}.png'.format(speaker_id, sigma, n_frames, k)))
  114.             plt.close("all")
  115.  
  116.         with torch.no_grad():
  117.             audio = waveglow.infer(mels.half(), sigma=0.8).float()
  118.  
  119.         audio = audio.cpu().numpy()[0]
  120.         # normalize audio for now
  121.         audio = audio / np.abs(audio).max()
  122.  
  123.         write(os.path.join(output_dir, 'sid{}_s{}_{}.wav'.format(speaker_id, sigma, file_suffix)),
  124.               data_config['sampling_rate'], audio)
  125.        
  126.         print("Ready! audio: {}, text: {}".format(audio.shape, text_lines[idx]))
  127.     print("Finished inferencing {} lines.".format(len(text_lines)))
  128.  
  129.  
  130. if __name__ == "__main__":
  131.     parser = argparse.ArgumentParser()
  132.     parser.add_argument('-c', '--config', type=str,
  133.                         help='JSON file for configuration')
  134.     parser.add_argument('-p', '--params', nargs='+', default=[])
  135.     parser.add_argument('-f', '--flowtron_path',
  136.                         help='Path to flowtron state dict', type=str)
  137.     parser.add_argument('-w', '--waveglow_path',
  138.                         help='Path to waveglow state dict', type=str)
  139.     parser.add_argument("--csv_path", default='batch_lines.csv', type=str, help='Path to csv file')
  140.     parser.add_argument('-i', '--id', help='Speaker id', type=int)
  141.     parser.add_argument('-n', '--n_frames', help='Number of frames',
  142.                         default=400, type=int)
  143.     parser.add_argument('-o', "--output_dir", default="results/")
  144.     parser.add_argument("-s", "--sigma", default=0.5, type=float)
  145.     parser.add_argument("-g", "--gate", default=0.5, type=float)
  146.     parser.add_argument("--seed", default=1234, type=int)
  147.     args = parser.parse_args()
  148.  
  149.     # Parse configs.  Globals nicer in this case
  150.     with open(args.config) as f:
  151.         data = f.read()
  152.  
  153.     global config
  154.     config = json.loads(data)
  155.     update_params(config, args.params)
  156.  
  157.     data_config = config["data_config"]
  158.     global model_config
  159.     model_config = config["model_config"]
  160.  
  161.     # Make directory if it doesn't exist
  162.     if not os.path.isdir(args.output_dir):
  163.         os.makedirs(args.output_dir)
  164.         os.chmod(args.output_dir, 0o775)
  165.  
  166.     torch.backends.cudnn.enabled = True
  167.     torch.backends.cudnn.benchmark = False
  168.     infer(args.flowtron_path, args.waveglow_path, args.output_dir, args.csv_path,
  169.           args.id, args.n_frames, args.sigma, args.gate, args.seed)
  170.  
Add Comment
Please, Sign In to add comment