Advertisement
DanialAhmed

object_detection.py

Dec 11th, 2019
202
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.44 KB | None | 0 0
  1. import numpy as np
  2. import os
  3. import six.moves.urllib as urllib
  4. import sys
  5. import tarfile
  6. import tensorflow as tf
  7. import zipfile
  8. import datetime
  9.  
  10. from collections import defaultdict
  11. from io import StringIO
  12. from matplotlib import pyplot as plt
  13. from PIL import Image
  14. from utils import label_map_util
  15. from EmailSender import SendEmail
  16.  
  17. sys.path.insert(0,r'C:\darknet')
  18. from darknet import *
  19. from utils import visualization_utils as vis_util
  20. import cv2
  21. def AccidentDetector(videofile):
  22.  
  23. cap = cv2.VideoCapture(videofile)
  24. #MODEL_NAME = 'Accident_Detection25487-resnet'
  25. MODEL_NAME = 'Accident_Detection42214-resnet'
  26. #MODEL_NAME = 'Accident_Detection200000'
  27. MODEL_FILE = MODEL_NAME + '.tar.gz'
  28. DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
  29. PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
  30. PATH_TO_LABELS = os.path.join('data', 'object-detection.pbtxt')
  31. NUM_CLASSES = 2
  32. detection_graph = tf.Graph()
  33. with detection_graph.as_default():
  34. od_graph_def = tf.GraphDef()
  35. with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
  36. serialized_graph = fid.read()
  37. od_graph_def.ParseFromString(serialized_graph)
  38. tf.import_graph_def(od_graph_def, name='')
  39. label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
  40. categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
  41. category_index = label_map_util.create_category_index(categories)
  42. def load_image_into_numpy_array(image):
  43. (im_width, im_height) = image.size
  44. return np.array(image.getdata()).reshape(
  45. (im_height, im_width, 3)).astype(np.uint8)
  46. PATH_TO_TEST_IMAGES_DIR = 'test_images'
  47. TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
  48. IMAGE_SIZE = (12, 8)
  49.  
  50.  
  51. check=0
  52. count=0
  53. prevdate=datetime.datetime.now()
  54. #currdate=datetime.datetime.now()
  55. with detection_graph.as_default():
  56. with tf.Session(graph=detection_graph) as sess:
  57. while True:
  58. ret, image_np = cap.read()
  59. imageOrg=image_np
  60. image_np_expanded = np.expand_dims(image_np, axis=0)
  61. image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
  62. boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
  63. scores = detection_graph.get_tensor_by_name('detection_scores:0')
  64. classes = detection_graph.get_tensor_by_name('detection_classes:0')
  65. num_detections = detection_graph.get_tensor_by_name('num_detections:0')
  66.  
  67. (boxes, scores, classes, num_detections) = sess.run(
  68. [boxes, scores, classes, num_detections],
  69. feed_dict={image_tensor: image_np_expanded})
  70. vis_util.visualize_boxes_and_labels_on_image_array(
  71. image_np,
  72. np.squeeze(boxes),
  73. np.squeeze(classes).astype(np.int32),
  74. np.squeeze(scores),
  75. category_index,
  76. use_normalized_coordinates=True,
  77. line_thickness=8)
  78. for index,value in enumerate(classes[0]):
  79. if scores[0,index] > 0.5:
  80. list1 = [[category_index.get(value)]]
  81. for i in list1:
  82. for j in i:
  83. if j['name'] == 'accident':
  84. if(check == 0):
  85. check=1
  86. #SendEmail(currdate)
  87. ts = datetime.datetime.now().timestamp()
  88. file = "detections\\"+str(ts)+".jpg"
  89. cv2.imwrite(file,imageOrg)
  90. # 1576044803.3581
  91. data,check = performDetect(imagePath=file)
  92. print(data,check)
  93. #name = "detections/frame%d.jpg"%count
  94. count = count + 1
  95.  
  96. cv2.imshow(videofile[0], cv2.resize(image_np, (600,600)))
  97. if cv2.waitKey(20) & 0xFF == ord('n'):
  98. current = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
  99. print(current)
  100. cap.set(cv2.CAP_PROP_POS_FRAMES,current+50)
  101.  
  102. if cv2.waitKey(20) & 0xFF == ord('p'):
  103. current = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
  104. print(current)
  105. cap.set(cv2.CAP_PROP_POS_FRAMES,current-50)
  106.  
  107. if cv2.waitKey(20) & 0xFF == ord('q'):
  108. cv2.destroyAllWindows()
  109. break
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement