SHARE
TWEET

TensorFlow Object Detection in Pictures

Posted by mdan on Mar 16th, 2018 · 183 views · Expires: Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1.  
  2. import numpy as np
  3. import os
  4. import six.moves.urllib as urllib
  5. import sys
  6. import tarfile
  7. import tensorflow as tf
  8. import zipfile
  9.  
  10. from collections import defaultdict
  11. from io import StringIO
  12. from matplotlib import pyplot as plt
  13. from PIL import Image
  14.  
  15.  
  16. from utils import label_map_util
  17.  
  18. from utils import visualization_utils as vis_util
  19.  
  20. MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
  21. MODEL_FILE = MODEL_NAME + '.tar.gz'
  22. DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
  23.  
  24.  
  25. PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
  26.  
  27. PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
  28.  
  29. NUM_CLASSES = 90
  30.  
  31. if not os.path.exists(MODEL_NAME + '/frozen_inference_graph.pb'):
  32.     print ('Downloading the model')
  33.     opener = urllib.request.URLopener()
  34.     opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
  35.     tar_file = tarfile.open(MODEL_FILE)
  36.     for file in tar_file.getmembers():
  37.       file_name = os.path.basename(file.name)
  38.       if 'frozen_inference_graph.pb' in file_name:
  39.         tar_file.extract(file, os.getcwd())
  40.     print ('Download complete')
  41. else:
  42.     print ('Model already exists')
  43.  
  44.  
  45. detection_graph = tf.Graph()
  46. with detection_graph.as_default():
  47.   od_graph_def = tf.GraphDef()
  48.   with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
  49.     serialized_graph = fid.read()
  50.     od_graph_def.ParseFromString(serialized_graph)
  51.     tf.import_graph_def(od_graph_def, name='')
  52.  
  53. label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
  54. categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
  55. category_index = label_map_util.create_category_index(categories)
  56.  
  57.  
  58. def load_image_into_numpy_array(image):
  59.   (im_width, im_height) = image.size
  60.   return np.array(image.getdata()).reshape(
  61.       (im_height, im_width, 3)).astype(np.uint8)
  62.  
  63.  
  64. PATH_TO_TEST_IMAGES_DIR = 'test_images'
  65. TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 4) ]
  66.  
  67. IMAGE_SIZE = (12, 8)
  68.  
  69. with detection_graph.as_default():
  70.   with tf.Session(graph=detection_graph) as sess:
  71.     for image_path in TEST_IMAGE_PATHS:
  72.       image = Image.open(image_path)
  73.       image_np = load_image_into_numpy_array(image)
  74.       image_np_expanded = np.expand_dims(image_np, axis=0)
  75.       image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
  76.       boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
  77.       scores = detection_graph.get_tensor_by_name('detection_scores:0')
  78.       classes = detection_graph.get_tensor_by_name('detection_classes:0')
  79.       num_detections = detection_graph.get_tensor_by_name('num_detections:0')
  80.       (boxes, scores, classes, num_detections) = sess.run(
  81.           [boxes, scores, classes, num_detections],
  82.           feed_dict={image_tensor: image_np_expanded})
  83.       vis_util.visualize_boxes_and_labels_on_image_array(
  84.           image_np,
  85.           np.squeeze(boxes),
  86.           np.squeeze(classes).astype(np.int32),
  87.           np.squeeze(scores),
  88.           category_index,
  89.           use_normalized_coordinates=True,
  90.           line_thickness=8)
  91.       plt.figure(figsize=IMAGE_SIZE)
  92.       plt.imshow(image_np)
Raw Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top