Advertisement
Guest User

Untitled

a guest
Jul 19th, 2018
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.94 KB | None | 0 0
import argparse
import json
import logging
import logging.config
import os
import re
import sys

import numpy as np

  9. def _parse_args():
  10.     parser = argparse.ArgumentParser(description='create dataset for end of utterance')
  11.     parser.add_argument(
  12.         '--asr'
  13.     )
  14.  
  15.     parser.add_argument(
  16.         '--align'
  17.     )
  18.  
  19.     parser.add_argument(
  20.         '--output'
  21.     )
  22.     return parser.parse_args()
  23.  
  24. def _main(args):
  25.     ASR_RESULT_DIR = args.asr
  26.     ALIGNEMENT_DIR = args.align
  27.     OUTPUT_PATH = args.output
  28.    
  29.     ch = logging.StreamHandler(sys.stderr)
  30.     ch.setLevel(logging.INFO)
  31.    
  32.     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
  33.     ch.setFormatter(formatter)
  34.    
  35.     logger = logging.Logger("")
  36.  
  37.     logger.addHandler(ch)
  38.  
  39.     def process_asr(asr_result_dir):
  40.         asr_feats = {}
  41.         with open(asr_result_dir, encoding='utf8') as f:
  42.             for line in f:
  43.                 as_json = json.loads(line.replace('nan', 'NaN').replace('inf', '1000'))
  44.                 for j in as_json['chunks']:
  45.                     if 'eou_detectors' in j:
  46.                         if 'features' not in j['eou_detectors'][0]['NeuralDetector']:
  47.                           continue
  48.                         asr_feats[as_json['id'].split('.wav')[0]] = np.reshape(
  49.                             np.array(j['eou_detectors'][0]['NeuralDetector']['features']
  50.                                     )
  51.                             .astype(float), (-1, 6))
  52.         logger.info('Finished processing asr feats, total length: {}'.format(len(asr_feats)))
  53.         return asr_feats
  54.  
  55.     def process_alignment(alignment_dir):
  56.         aligned_feats = {}
  57.         with open(alignment_dir, encoding='utf-8') as f:
  58.             aligned = json.loads(f.readline())
  59.             for line in aligned:
  60.                 index = line['id']
  61.                 word_times = line['words']
  62.                 starts = np.array(word_times)[:, 1].astype(float)
  63.                 ends = np.array(word_times)[:, 2].astype(float)
  64.                 if (len(starts) <= 1):
  65.                     aligned_feats[index] = round(sorted(word_times, key=lambda k: k[-1])[-1][-1] * 100)
  66.                     continue
  67.                 max_pause_time = np.max(np.abs(ends[:-1] - starts[1:]))
  68.                 if max_pause_time >= 1.4:
  69.                     continue
  70.                 aligned_feats[index] = round(sorted(word_times, key=lambda k: k[-1])[-1][-1] * 100)
  71.         logger.info('Finished processing alignment, total length: {}'.format(len(aligned_feats)))
  72.         return aligned_feats
  73.  
  74.  
  75.     def get_features(asr_feats, probas_feats, aligned_feats):
  76.         out_dict = {}
  77.         for index in asr_feats:
  78.             if index not in aligned_feats:
  79.                 continue
  80.             asr = asr_feats[index]
  81.             align = aligned_feats[index]
  82.             cutoff = proba.shape[0] - asr.shape[0]
  83.             if asr.shape[0] != proba[cutoff:].shape[0]:
  84.               continue
  85.             feats = np.concatenate((asr, np.reshape(proba[cutoff:], (-1, 1))), axis=1)
  86.             label = align - cutoff + 10
  87.             if label > feats.shape[0]:
  88.                 continue
  89.             label_packed = np.reshape(np.repeat(label, feats.shape[-1]), (1, -1))
  90.             out_dict[index] = np.concatenate((feats, label_packed), axis=0)
  91.  
  92.         logger.info('Finished processing features, total length: {}'.format(len(out_dict)))
  93.         return out_dict
  94.  
  95.     asr_feats = process_asr(ASR_RESULT_DIR)
  96.     probas_feats = process_probas(PROBAS_DIR)
  97.     aligned_feats = process_alignment(ALIGNEMENT_DIR)
  98.  
  99.  
  100.     feats = get_features(asr_feats=asr_feats,
  101.                          probas_feats=probas_feats,
  102.                          aligned_feats=aligned_feats)
  103.     logger.info('Writing to {}'.format(OUTPUT_PATH))
  104.     with open(OUTPUT_PATH, "w") as fout:
  105.         np.save(OUTPUT_PATH, feats)
  106.     logger.info('Written to {}'.format(OUTPUT_PATH))
  107.  
  108. if __name__ == '__main__':
  109.     _main(_parse_args())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement