Advertisement
Guest User

Untitled

a guest
May 27th, 2017
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.22 KB | None | 0 0
  1. import numpy as np
  2. import tensorflow as tf
  3. import os
  4. import glob
  5.  
  6. def tf2npz(tf_path, export_folder='/ssd/yt8m/data_npz/'):
  7. vid_ids = []
  8. labels = []
  9. mean_rgb = []
  10. mean_audio = []
  11. tf_basename = os.path.basename(tf_path)
  12. npz_basename = tf_basename[:-len('.tfrecord')] + '.npz'
  13. isTrain = '/test' not in tf_path
  14.  
  15. for example in tf.python_io.tf_record_iterator(tf_path):
  16. tf_example = tf.train.Example.FromString(example).features
  17. vid_ids.append(tf_example.feature['video_id'].bytes_list.value[0].decode(encoding='UTF-8'))
  18. if isTrain:
  19. labels.append(np.array(tf_example.feature['labels'].int64_list.value))
  20. mean_rgb.append(np.array(tf_example.feature['mean_rgb'].float_list.value).astype(np.float16))
  21. mean_audio.append(np.array(tf_example.feature['mean_audio'].float_list.value).astype(np.float16))
  22.  
  23. save_path = export_folder + '/' + npz_basename
  24. np.savez(save_path,
  25. rgb=np.array(mean_rgb),
  26. audio=np.array(mean_audio),
  27. ids=np.array(vid_ids),
  28. labels=labels
  29. )
  30.  
  31. from multiprocessing import Pool
  32. with Pool(6) as p:
  33. p.map(tf2npz, glob.glob('/ssd/yt8m/data_tfrecord/*.tfrecord'))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement