SHARE
TWEET

TensorFlow Object Detection Webcam

mdan Mar 16th, 2018 178 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # Hai sa importam cele de trebuinta
  2.  
  3. import numpy as np
  4. import os
  5. import six.moves.urllib as urllib
  6. import sys
  7. import tarfile
  8. import tensorflow as tf
  9. import zipfile
  10.  
  11. from collections import defaultdict
  12. from io import StringIO
  13. from matplotlib import pyplot as plt
  14. from PIL import Image
  15.  
  16.  
  17. from utils import label_map_util
  18.  
  19. from utils import visualization_utils as vis_util
  20.  
  21. # Pregatim modelul
  22.  
  23. '''
  24. Orice model exportat folosind 'export_inference_graph.py'
  25. poate fi incarcat aici prin schimbarea 'PATH_TO_CKPT'
  26. in asa fel incat noua destinatie sa corespunda
  27. noului fisier .pb
  28.  
  29. Eu am folosit aici modelul 'SSD with Mobilenet'
  30. '''
  31.  
  32. # Stabilim ce model sa downloadam.
  33. MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
  34. MODEL_FILE = MODEL_NAME + '.tar.gz'
  35. DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
  36.  
  37. #Path-ul catre modelul folosit pentru identificarea obiectelor
  38. #In cazul nostru, frozen detection graph
  39.  
  40. PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
  41.  
  42. #Path-ul catre lista de denumiri ale obiectelor
  43.  
  44. PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
  45.  
  46. NUM_CLASSES = 90
  47.  
  48.  
  49. # Acush sa download modelu'
  50.  
  51. if not os.path.exists(MODEL_NAME + '/frozen_inference_graph.pb'):
  52.     print ('Downloading the model')
  53.     opener = urllib.request.URLopener()
  54.     opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
  55.     tar_file = tarfile.open(MODEL_FILE)
  56.     for file in tar_file.getmembers():
  57.       file_name = os.path.basename(file.name)
  58.       if 'frozen_inference_graph.pb' in file_name:
  59.         tar_file.extract(file, os.getcwd())
  60.     print ('Download complete')
  61. else:
  62.     print ('Model already exists')
  63.  
  64. # Sa bagam in memorie un model Tensorflow (frozen)
  65.  
  66. detection_graph = tf.Graph()
  67. with detection_graph.as_default():
  68.   od_graph_def = tf.GraphDef()
  69.   with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
  70.     serialized_graph = fid.read()
  71.     od_graph_def.ParseFromString(serialized_graph)
  72.     tf.import_graph_def(od_graph_def, name='')
  73.  
  74.  
  75. # Ce e aia Label map?
  76. # Label map este o lista de indici. Cand reteaua noastra neuronala
  77. # face o predictie, sa zicem 5, asta inseamna ca
  78. # a gasit un avion.
  79.  
  80. label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
  81. categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
  82. category_index = label_map_util.create_category_index(categories)
  83.  
  84. # Sa initializam webcam-ul...
  85.  
  86. import cv2
  87. cap = cv2.VideoCapture(0)
  88.  
  89. # ... si sa punem tensorflow la treaba!
  90.  
  91. with detection_graph.as_default():
  92.   with tf.Session(graph=detection_graph) as sess:
  93.    ret = True
  94.    while (ret):
  95.       ret,image_np = cap.read()
  96.       # Trebuie sa expandam putin dimensiunile, din moment ce modelul
  97.       # se asteapta ca imaginile sa aiba o forma: [1, None, None, 3]
  98.       image_np_expanded = np.expand_dims(image_np, axis=0)
  99.       image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
  100.       # Fiecare contur (box) reprezinta o parte a imaginii
  101.       # unde a fost detectat un obiect cunoscut
  102.       boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
  103.       # Fiecare scor reprezinta nivelul de incredere,
  104.       # ca ce scrie e tot aia cu ce se vede
  105.       # Adica identificarea este corecta intr-o proportie de x%
  106.       # Scorul apare langa obiect impreuna cu denumirea obiectului (eticheta).
  107.       scores = detection_graph.get_tensor_by_name('detection_scores:0')
  108.       classes = detection_graph.get_tensor_by_name('detection_classes:0')
  109.       num_detections = detection_graph.get_tensor_by_name('num_detections:0')
  110.       # Aici se produce detectia propriu-zisa
  111.       (boxes, scores, classes, num_detections) = sess.run(
  112.           [boxes, scores, classes, num_detections],
  113.           feed_dict={image_tensor: image_np_expanded})
  114.       # Aici se gaseste vizualizarea
  115.       vis_util.visualize_boxes_and_labels_on_image_array(
  116.           image_np,
  117.           np.squeeze(boxes),
  118.           np.squeeze(classes).astype(np.int32),
  119.           np.squeeze(scores),
  120.           category_index,
  121.           use_normalized_coordinates=True,
  122.           line_thickness=8)
  123.       cv2.imshow('image',cv2.resize(image_np,(1280,960)))
  124.       if cv2.waitKey(25) & 0xFF == ord('q'):
  125.           cv2.destroyAllWindows()
  126.           cap.release()
  127.           break
  128.      
  129. # Ura si la gara!
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top