Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import tensorflow as tf
- import os
- import glob
def tf2npz(tf_path, export_folder='/ssd/yt8m/data_npz/'):
    """Convert one TFRecord shard of serialized ``tf.train.Example`` records to ``.npz``.

    For every record the ``video_id``, ``mean_rgb`` and ``mean_audio``
    features are read; ``labels`` is read only when ``tf_path`` does not
    contain ``'/test'``. The result is written with ``np.savez`` to
    ``<export_folder>/<shard name>.npz`` under the keys ``rgb``, ``audio``,
    ``ids`` and ``labels``.

    Args:
        tf_path: Path of the ``.tfrecord`` file to read.
        export_folder: Directory that receives the ``.npz`` file. It is not
            created here and must already exist.

    Returns:
        None. The only effect is the file written to ``export_folder``.

    Note:
        When videos carry different numbers of labels, ``labels`` is stored
        as an object array, so it must be read back with
        ``np.load(path, allow_pickle=True)``.
    """
    vid_ids = []
    labels = []
    mean_rgb = []
    mean_audio = []

    # Strip the extension only when it is really there, so an unexpected
    # file name is never truncated into something unrecognisable.
    shard_name = os.path.basename(tf_path)
    if shard_name.endswith('.tfrecord'):
        shard_name = shard_name[:-len('.tfrecord')]

    # Any path containing '/test' is treated as unlabeled.
    has_labels = '/test' not in tf_path

    # NOTE: tf.python_io is the TensorFlow 1.x location of this iterator;
    # under TensorFlow 2.x it is tf.compat.v1.io.tf_record_iterator.
    for record in tf.python_io.tf_record_iterator(tf_path):
        features = tf.train.Example.FromString(record).features.feature
        vid_ids.append(
            features['video_id'].bytes_list.value[0].decode(encoding='UTF-8'))
        if has_labels:
            labels.append(np.array(features['labels'].int64_list.value))
        # The float features are stored as float16 to halve the size on disk.
        mean_rgb.append(
            np.array(features['mean_rgb'].float_list.value).astype(np.float16))
        mean_audio.append(
            np.array(features['mean_audio'].float_list.value).astype(np.float16))

    np.savez(
        os.path.join(export_folder, shard_name + '.npz'),
        rgb=np.array(mean_rgb),
        audio=np.array(mean_audio),
        ids=np.array(vid_ids),
        labels=_pack_labels(labels),
    )


def _pack_labels(labels):
    """Return the per-video label arrays as a single array ``np.savez`` can store."""
    if len({len(video_labels) for video_labels in labels}) <= 1:
        # Empty or uniform-length input already forms a regular array.
        return np.array(labels)
    # Ragged input: recent NumPy refuses to build the array implicitly, so
    # create a 1-D object array with one label array per video.
    packed = np.empty(len(labels), dtype=object)
    for index, video_labels in enumerate(labels):
        packed[index] = video_labels
    return packed
from multiprocessing import Pool

# Convert every shard in parallel. The guard is required because worker
# processes started with the "spawn" method re-import this module; without
# it each worker would try to create its own pool.
if __name__ == '__main__':
    with Pool(6) as p:
        p.map(tf2npz, glob.glob('/ssd/yt8m/data_tfrecord/*.tfrecord'))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement