Advertisement
alseambusher

Final3

Aug 8th, 2019
11,925
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.07 KB | None | 0 0
  1. import skvideo.io
  2. import skvideo.datasets
  3. import tensorflow as tf
  4. from tensorflow.keras.applications.resnet50 import ResNet50
  5. from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
  6. import cv2
  7. import json
  8. import numpy as np
  9.  
  10. # enabling eager execution for easier explanation
  11. tf.enable_eager_execution()
  12.  
  13. model = ResNet50(weights='imagenet')
  14.  
  15. reader = skvideo.io.FFmpegReader(skvideo.datasets.bigbuckbunny(),
  16.                              inputdict={},
  17.                              outputdict={})
  18. def gen_frames():
  19.     for frame in reader.nextFrame():
  20.         yield frame
  21.  
  22. dataset = tf.data.Dataset.from_generator(gen_frames, tf.int64)
  23.  
  24. def preprocess(frame):
  25.     x = tf.image.resize_bilinear(frame, [224, 224])
  26.     x = preprocess_input(x)
  27.     return x, frame
  28.  
  29. dataset = dataset.batch(64).map(preprocess, 10).prefetch(1)
  30.  
  31. def predict():
  32.     with tf.device("/gpu:0"):
  33.         for frames, original in dataset:
  34.             yield model.predict(frames.numpy()), original
  35.  
  36. dataset2 = tf.data.Dataset.from_generator(predict, (tf.float64, tf.int64))
  37.  
  38. def postprocess(output, original):
  39.     # do some post processing
  40.     return tf.argsort(output)[:3], original
  41.  
  42. dataset2 = dataset2.apply(tf.data.experimental.unbatch()).map(postprocess, 10)
  43.  
  44. with open("imagenet_class_index.json") as f:
  45.     CLASS_INDEX = json.load(f)
  46.  
  47. writer = skvideo.io.FFmpegWriter("output.mp4")
  48. for value, value2 in dataset2:
  49.     indices = value.numpy()
  50.     f = value2.numpy()
  51.     f = np.ascontiguousarray(f, dtype=np.uint8)
  52.     cv2.putText(f, CLASS_INDEX[str(indices[0])][1], (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), lineType=cv2.LINE_AA)
  53.     cv2.putText(f, CLASS_INDEX[str(indices[1])][1], (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), lineType=cv2.LINE_AA)
  54.     cv2.putText(f, CLASS_INDEX[str(indices[2])][1], (20, 180), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), lineType=cv2.LINE_AA)
  55.     writer.writeFrame(f)
  56.     print(CLASS_INDEX[str(indices[0])][1], CLASS_INDEX[str(indices[1])][1], CLASS_INDEX[str(indices[2])][1])
  57. writer.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement