Guest User

Untitled

a guest
Nov 19th, 2018
140
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.35 KB | None | 0 0
  1. import os
  2. import cv2
  3. import numpy as np
  4. from picamera.array import PiRGBArray
  5. from picamera import PiCamera
  6. import tensorflow as tf
  7. import argparse
  8. import sys
  9.  
  10. # Set up camera constants
  11. IM_WIDTH = 1280
  12. IM_HEIGHT = 720
  13. #IM_WIDTH = 640 Use smaller resolution for
  14. #IM_HEIGHT = 480 slightly faster framerate
  15.  
  16. # Select camera type (if user enters --usbcam when calling this script,
  17. # a USB webcam will be used)
  18. camera_type = 'picamera'
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument('--usbcam', help='Use a USB webcam instead of picamera',
  21. action='store_true')
  22. args = parser.parse_args()
  23. if args.usbcam:
  24. camera_type = 'usb'
  25.  
  26. # This is needed since the working directory is the object_detection folder.
  27. sys.path.append('..')
  28.  
  29. # Import utilites
  30. from utils import label_map_util
  31. from utils import visualization_utils as vis_util
  32.  
  33. # Name of the directory containing the object detection module we're using
  34. MODEL_NAME = 'ssdlite_mobilenet_v2_coco_2018_05_09'
  35.  
  36. # Grab path to current working directory
  37. CWD_PATH = os.getcwd()
  38.  
  39. # Path to frozen detection graph .pb file, which contains the model that is used
  40. # for object detection.
  41. PATH_TO_CKPT = os.path.join(CWD_PATH,MODEL_NAME,'frozen_inference_graph.pb')
  42.  
  43. # Path to label map file
  44. PATH_TO_LABELS = os.path.join(CWD_PATH,'data','mscoco_label_map.pbtxt')
  45.  
  46. # Number of classes the object detector can identify
  47. NUM_CLASSES = 90
  48.  
  49. ## Load the label map.
  50. # Label maps map indices to category names, so that when the convolution
  51. # network predicts `5`, we know that this corresponds to `airplane`.
  52. # Here we use internal utility functions, but anything that returns a
  53. # dictionary mapping integers to appropriate string labels would be fine
  54. label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
  55. categories = label_map_util.convert_label_map_to_categories(label_map,
  56. max_num_classes=NUM_CLASSES, use_display_name=True)
  57. category_index = label_map_util.create_category_index(categories)
  58.  
  59. # Load the Tensorflow model into memory.
  60. detection_graph = tf.Graph()
  61. with detection_graph.as_default():
  62. od_graph_def = tf.GraphDef()
  63. with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
  64. serialized_graph = fid.read()
  65. od_graph_def.ParseFromString(serialized_graph)
  66. tf.import_graph_def(od_graph_def, name='')
  67.  
  68. sess = tf.Session(graph=detection_graph)
  69.  
  70.  
  71. # Define input and output tensors (i.e. data) for the object detection
  72. classifier
  73.  
  74. # Input tensor is the image
  75. image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
  76.  
  77. # Output tensors are the detection boxes, scores, and classes
  78. # Each box represents a part of the image where a particular object was detected
  79. detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
  80.  
  81. # Each score represents level of confidence for each of the objects.
  82. # The score is shown on the result image, together with the class label.
  83. detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
  84. detection_classes =
  85. detection_graph.get_tensor_by_name('detection_classes:0')
  86.  
  87. # Number of objects detected
  88. num_detections = detection_graph.get_tensor_by_name('num_detections:0')
  89.  
  90. # Initialize frame rate calculation
  91. frame_rate_calc = 1
  92. freq = cv2.getTickFrequency()
  93. font = cv2.FONT_HERSHEY_SIMPLEX
  94.  
  95. # Initialize camera and perform object detection.
  96. # The camera has to be set up and used differently depending on if it's a
  97. # Picamera or USB webcam.
  98.  
  99. # I know this is ugly, but I basically copy+pasted the code for the object
  100. # detection loop twice, and made one work for Picamera and the other work
  101. # for USB.
  102.  
  103. ### Picamera ###
  104. if camera_type == 'picamera':
  105. # Initialize Picamera and grab reference to the raw capture
  106. camera = PiCamera()
  107. camera.resolution = (IM_WIDTH,IM_HEIGHT)
  108. camera.framerate = 10
  109. rawCapture = PiRGBArray(camera, size=(IM_WIDTH,IM_HEIGHT))
  110. rawCapture.truncate(0)
  111.  
  112. for frame1 in camera.capture_continuous(rawCapture, format="bgr",use_video_port=True):
  113.  
  114. t1 = cv2.getTickCount()
  115.  
  116. # Acquire frame and expand frame dimensions to have shape: [1, None, None, 3]
  117. # i.e. a single-column array, where each item in the column has the pixel RGB value
  118. frame = frame1.array
  119. frame.setflags(write=1)
  120. frame_expanded = np.expand_dims(frame, axis=0)
  121.  
  122. # Perform the actual detection by running the model with the image as input
  123. (boxes, scores, classes, num) = sess.run(
  124. [detection_boxes, detection_scores, detection_classes, num_detections],
  125. feed_dict={image_tensor: frame_expanded})
  126.  
  127. # Draw the results of the detection (aka 'visulaize the results')
  128. vis_util.visualize_boxes_and_labels_on_image_array(
  129. frame,
  130. np.squeeze(boxes),
  131. np.squeeze(classes).astype(np.int32),
  132. np.squeeze(scores),
  133. category_index,
  134. use_normalized_coordinates=True,
  135. line_thickness=8,
  136. min_score_thresh=0.40)
  137.  
  138. # Blue line
  139. cv2.line(frame, (IM_WIDTH // 2, 0), (IM_WIDTH // 2 , IM_WIDTH), (250, 0, 1), 2)
  140. # Red line
  141. cv2.line(frame, (IM_WIDTH // 2 - 50, 0), (IM_WIDTH // 2 - 50, IM_WIDTH), (0, 0, 255), 2)
  142.  
  143. # FPS Text
  144. cv2.putText(frame,"FPS: {0:.2f}".format(frame_rate_calc),(30,50),font,1,(255,255,0),2,cv2.LINE_AA)
  145.  
  146. # All the results have been drawn on the frame, so it's time to display it.
  147. cv2.imshow('Object detector', frame)
  148.  
  149. t2 = cv2.getTickCount()
  150. time1 = (t2-t1)/freq
  151. frame_rate_calc = 1/time1
  152.  
  153. # Press 'q' to quit
  154. if cv2.waitKey(1) == ord('q'):
  155. break
  156.  
  157. rawCapture.truncate(0)
  158.  
  159. camera.close()
Add Comment
Please, Sign In to add comment