Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import json
- import os
- import numpy as np
- import logging
- import logging.config
- import sys
- import argparse
def _parse_args(argv=None):
    """Parse the command-line options for building the end-of-utterance dataset.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        argparse.Namespace with attributes ``asr``, ``align``, ``output`` and
        ``probas`` (``None`` when ``--probas`` is not given).
    """
    parser = argparse.ArgumentParser(description='create dataset for end of utterance')
    # All three inputs are needed to build the dataset; without required=True a
    # missing option only surfaced later as open(None) -> TypeError.
    parser.add_argument(
        '--asr',
        required=True,
        help='JSON-lines file with ASR results, one JSON object per line',
    )
    parser.add_argument(
        '--align',
        required=True,
        help='alignment file whose first line is a JSON list of {"id", "words"} records',
    )
    parser.add_argument(
        '--output',
        required=True,
        help='path the resulting dataset is written to (numpy.save format)',
    )
    parser.add_argument(
        '--probas',
        default=None,
        help='optional file with per-utterance frame probabilities',
    )
    return parser.parse_args(argv)
def _read_asr_features(asr_path, feature_dim=6):
    """Load per-utterance ASR end-of-utterance features from a JSON-lines file.

    Each line is a JSON object with an ``id`` and a list of ``chunks``. For every
    chunk that has ``eou_detectors[0]['NeuralDetector']['features']``, the flat
    feature list is converted to float and reshaped to ``(-1, feature_dim)``.
    When several chunks of one line carry features, the last one wins.

    Args:
        asr_path: Path of the JSON-lines ASR result file.
        feature_dim: Number of feature columns per frame.

    Returns:
        dict mapping the id (cut at the first ``'.wav'``) to a float ndarray of
        shape ``(n_frames, feature_dim)``.
    """
    import re

    # The producer writes bare ``nan`` / ``inf`` tokens, which are not valid JSON.
    # Only whole tokens are rewritten: a plain substring replace would also
    # mangle ids and words (e.g. "financial" -> "fiNaNcial"), making those
    # utterances silently fail to match the alignment ids.
    nan_token = re.compile(r'\bnan\b')
    inf_token = re.compile(r'\binf\b')

    asr_feats = {}
    with open(asr_path, encoding='utf8') as f:
        for line in f:
            record = json.loads(inf_token.sub('1000', nan_token.sub('NaN', line)))
            for chunk in record['chunks']:
                if 'eou_detectors' not in chunk:
                    continue
                detector = chunk['eou_detectors'][0]['NeuralDetector']
                if 'features' not in detector:
                    continue
                asr_feats[record['id'].split('.wav')[0]] = np.reshape(
                    np.array(detector['features']).astype(float), (-1, feature_dim))
    return asr_feats


def _read_alignment_labels(alignment_path, max_pause_sec=1.4):
    """Derive one end-of-utterance position per utterance from a word alignment.

    Only the first line of the file is read; it must hold a JSON list of records
    with an ``id`` and ``words``, where each word is ``[word, start, end, ...]``.
    Utterances with no words, or with a gap of at least ``max_pause_sec``
    between consecutive words, are dropped.

    Args:
        alignment_path: Path of the alignment file.
        max_pause_sec: Largest allowed ``|end[i] - start[i+1]|`` between words.

    Returns:
        dict mapping id to ``round(max(last field of each word) * 100)``. The
        x100 presumably converts seconds to 10 ms frames — TODO confirm.
    """
    with open(alignment_path, encoding='utf-8') as f:
        aligned = json.loads(f.readline())

    aligned_feats = {}
    for entry in aligned:
        word_times = entry['words']
        if not word_times:
            # Nothing aligned, so there is no end time (previously an IndexError).
            continue
        if len(word_times) > 1:
            starts = np.array([float(word[1]) for word in word_times])
            ends = np.array([float(word[2]) for word in word_times])
            if np.max(np.abs(ends[:-1] - starts[1:])) >= max_pause_sec:
                continue
        aligned_feats[entry['id']] = round(max(word[-1] for word in word_times) * 100)
    return aligned_feats


def _read_probas(probas_path):
    """Load per-utterance frame probabilities.

    NOTE(review): the original script called an undefined ``process_probas`` on an
    undefined ``PROBAS_DIR``, so the real input format is unknown. This assumes a
    ``.npy`` file holding a pickled dict ``{utterance_id: 1-D array}`` (the same
    container format this script writes) — confirm with the producer.
    SECURITY: ``allow_pickle=True`` can execute arbitrary code; load trusted files only.
    """
    loaded = np.load(probas_path, allow_pickle=True).item()
    return {key: np.asarray(value, dtype=float) for key, value in loaded.items()}


def _build_examples(asr_feats, probas_feats, aligned_feats, label_offset=10):
    """Combine features and labels into one array per utterance.

    For each id present in both ``asr_feats`` and ``aligned_feats`` (and in
    ``probas_feats`` when it is not ``None``), the ASR frames are optionally
    extended with one column taken from the last ``n_frames`` proba values.
    A final row filled with the label ``align - cutoff + label_offset`` is then
    appended, where ``cutoff`` is the number of leading proba frames dropped.
    Utterances whose label exceeds the frame count, or whose proba track is
    shorter than the ASR track, are skipped.

    Returns:
        dict mapping id to an array of shape ``(n_frames + 1, n_columns)``.
    """
    out_dict = {}
    for index, asr in asr_feats.items():
        if index not in aligned_feats:
            continue
        if probas_feats is None:
            feats = asr
            cutoff = 0
        else:
            if index not in probas_feats:
                continue
            proba = probas_feats[index]
            # Keep the trailing frames so both tracks end on the same frame.
            cutoff = proba.shape[0] - asr.shape[0]
            if cutoff < 0:
                continue
            feats = np.concatenate((asr, np.reshape(proba[cutoff:], (-1, 1))), axis=1)
        # label_offset (10) is a magic constant from the original; its meaning
        # (a frame margin?) is not documented — TODO confirm.
        label = aligned_feats[index] - cutoff + label_offset
        if label > feats.shape[0]:
            continue
        label_row = np.full((1, feats.shape[1]), label)
        out_dict[index] = np.concatenate((feats, label_row), axis=0)
    return out_dict


def _main(args):
    """Build the end-of-utterance dataset and write it with ``numpy.save``.

    Reads ``args.asr``, ``args.align`` and ``args.output``. It also reads
    ``args.probas`` when that attribute exists and is not ``None``. The output
    is a dict (saved as a pickled 0-d object array) mapping utterance id to
    features plus a trailing label row. Read it back with
    ``np.load(path, allow_pickle=True).item()``.
    """
    logger = logging.Logger("")
    handler = logging.StreamHandler(sys.stderr)
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)

    asr_feats = _read_asr_features(args.asr)
    logger.info('Finished processing asr feats, total length: %d', len(asr_feats))

    # The original referenced an undefined process_probas/PROBAS_DIR and always
    # crashed; probas are now optional and only used when provided.
    probas_path = getattr(args, 'probas', None)
    probas_feats = None
    if probas_path is not None:
        probas_feats = _read_probas(probas_path)
        logger.info('Finished processing probas, total length: %d', len(probas_feats))

    aligned_feats = _read_alignment_labels(args.align)
    logger.info('Finished processing alignment, total length: %d', len(aligned_feats))

    feats = _build_examples(asr_feats, probas_feats, aligned_feats)
    logger.info('Finished processing features, total length: %d', len(feats))

    logger.info('Writing to %s', args.output)
    # Save through a binary handle. This writes to exactly args.output; passing
    # a path would let np.save append ".npy". It also avoids the old
    # text-mode open that left an empty file behind.
    with open(args.output, 'wb') as fout:
        np.save(fout, feats)
    logger.info('Written to %s', args.output)
# Script entry point: parse the command line and build the dataset.
if __name__ == '__main__':
    _main(args=_parse_args())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement