daily pastebin goal
65%
SHARE
TWEET

Untitled

a guest Mar 24th, 2019 82 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. """
  2. Usage:
  3.  
  4. python3 demo/classify_capture_opencv.py \
  5.     --model test_data/inception_v4_299_quant_edgetpu.tflite  \
  6.     --label test_data/imagenet_labels.txt
  7.  
  8. """
  9. import argparse
  10. import io
  11. import time
  12.  
  13. import numpy as np
  14.  
  15. import cv2
  16.  
  17. import edgetpu.classification.engine
  18.  
  19.  
  20. def main():
  21.     parser = argparse.ArgumentParser()
  22.     parser.add_argument(
  23.       '--model', help='File path of Tflite model.', required=True)
  24.     parser.add_argument(
  25.       '--label', help='File path of label file.', required=True)
  26.     args = parser.parse_args()
  27.  
  28.     with open(args.label, 'r') as f:
  29.         pairs = (l.strip().split(maxsplit=1) for l in f.readlines())
  30.         labels = dict((int(k), v) for k, v in pairs)
  31.  
  32.     engine = edgetpu.classification.engine.ClassificationEngine(args.model)
  33.  
  34.     try:
  35.         cap = cv2.VideoCapture(0)
  36.         font = cv2.FONT_HERSHEY_SIMPLEX
  37.         _, width, height, channels = engine.get_input_tensor_shape()
  38.         while True:
  39.             ret, frame = cap.read()
  40.  
  41.             # Display the resulting frame
  42.             if cv2.waitKey(1) & 0xFF == ord('q'):
  43.                 break
  44.  
  45.             resized = cv2.resize(frame, (width, height))
  46.             input = np.frombuffer(resized, dtype=np.uint8)
  47.             start_time = time.time()
  48.             results = engine.ClassifyWithInputTensor(input, top_k=1)
  49.             elapsed_time = time.time() - start_time
  50.             if results:
  51.                 confidence = results[0][1]
  52.                 label = labels[results[0][0]]
  53.                 print("Elapsed time: {:0.02f}".format(elapsed_time * 1000))
  54.             cv2.putText(frame, label, (0, 30), font, 1, (255, 255, 255), 2, cv2.LINE_AA)
  55.             cv2.putText(frame, "{:0.02f}".format(confidence), (0, 50), font, 1, (255, 255, 255), 2, cv2.LINE_AA)
  56.             cv2.imshow('frame', frame)
  57.     finally:
  58.         cap.release()
  59.         cv2.destroyAllWindows()
  60.  
  61.  
  62. if __name__ == '__main__':
  63.     main()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top