Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def run_inference(graph):
    """Run detection on every file in PATH_TO_TEST_IMAGES and save overlays.

    For each directory entry the image is loaded, pushed through the
    detection graph, annotated in place with
    ``visualize_masks_and_labels_on_image_array`` and written to
    ``"<index>.png"`` in the current working directory.

    Args:
        graph: Graph object providing ``as_default()``; it must contain an
            ``image_tensor:0`` input and whichever of the ``num_detections``,
            ``detection_boxes``, ``detection_scores``, ``detection_classes``
            and ``detection_masks`` outputs the model exposes.

    Returns:
        dict mapping the enumeration index of each directory entry to its
        post-processed output dict (batch dimension stripped).
    """
    results = {}
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors. These lookups do not
            # depend on the image, so they are done once rather than per file.
            default_graph = tf.get_default_graph()
            ops = default_graph.get_operations()
            all_tensor_names = {
                output.name for op in ops for output in op.outputs}
            base_tensor_dict = {}
            for key in (
                    'num_detections', 'detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_masks'):
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    base_tensor_dict[key] = default_graph.get_tensor_by_name(
                        tensor_name)
            image_tensor = default_graph.get_tensor_by_name('image_tensor:0')

            for i, image_path in enumerate(os.listdir(PATH_TO_TEST_IMAGES)):
                # The array based representation of the image is fed to the
                # model and later used to prepare the result image with the
                # overlays on it. The context manager releases the file
                # handle once the pixels have been copied out.
                # NOTE(review): assumes load_image_into_numpy_array fully
                # reads the image before returning -- confirm.
                with Image.open(
                        os.path.join(PATH_TO_TEST_IMAGES, image_path)) as image:
                    image_np = load_image_into_numpy_array(image)
                # Expand dimensions since the model expects images to have
                # shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_np, axis=0)

                # Start from the shared handles; only the mask entry is
                # replaced per image because it depends on the image size.
                tensor_dict = dict(base_tensor_dict)
                if 'detection_masks' in tensor_dict:
                    # The following processing is only for single image
                    detection_boxes = tf.squeeze(
                        tensor_dict['detection_boxes'], [0])
                    detection_masks = tf.squeeze(
                        tensor_dict['detection_masks'], [0])
                    # Reframe is required to translate mask from box
                    # coordinates to image coordinates and fit the image size.
                    # NOTE(review): these ops are added to the graph again for
                    # every image, so the graph grows with the number of files.
                    real_num_detection = tf.cast(
                        tensor_dict['num_detections'][0], tf.int32)
                    detection_boxes = tf.slice(
                        detection_boxes, [0, 0], [real_num_detection, -1])
                    detection_masks = tf.slice(
                        detection_masks, [0, 0, 0],
                        [real_num_detection, -1, -1])
                    detection_masks_reframed = (
                        utils_ops.reframe_box_masks_to_image_masks(
                            detection_masks, detection_boxes,
                            image_np.shape[0], image_np.shape[1]))
                    detection_masks_reframed = tf.cast(
                        tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                    # Follow the convention by adding back the batch dimension
                    tensor_dict['detection_masks'] = tf.expand_dims(
                        detection_masks_reframed, 0)

                # Run inference on the already-batched array (previously this
                # array was computed but the PIL image was re-converted here).
                output_dict = sess.run(
                    tensor_dict, feed_dict={image_tensor: image_np_expanded})

                # all outputs are float32 numpy arrays, so convert types as
                # appropriate
                output_dict['num_detections'] = int(
                    output_dict['num_detections'][0])
                output_dict['detection_classes'] = output_dict[
                    'detection_classes'][0].astype(np.uint8)
                output_dict['detection_boxes'] = output_dict[
                    'detection_boxes'][0]
                output_dict['detection_scores'] = output_dict[
                    'detection_scores'][0]
                if 'detection_masks' in output_dict:
                    output_dict['detection_masks'] = output_dict[
                        'detection_masks'][0]

                visualize_masks_and_labels_on_image_array(
                    image_np,
                    output_dict['detection_boxes'],
                    output_dict['detection_classes'],
                    output_dict['detection_scores'],
                    category_index,
                    instance_masks=output_dict.get('detection_masks'),
                    use_normalized_coordinates=True,
                    line_thickness=0)
                plt.figure(figsize=IMAGE_SIZE)
                plt.imshow(image_np)
                plt.imsave("{i}.png".format(i=i), image_np, format="png")
                results[i] = output_dict
    return results
def visualize_masks_and_labels_on_image_array(
        image,
        boxes,
        classes,
        scores,
        category_index,
        instance_masks=None,
        instance_boundaries=None,
        keypoints=None,
        track_ids=None,
        use_normalized_coordinates=False,
        max_boxes_to_draw=20,
        min_score_thresh=.5,
        agnostic_mode=False,
        line_thickness=4,
        groundtruth_box_visualization_color='black',
        skip_scores=False,
        skip_labels=False,
        skip_track_ids=False):
    """Overlay instance masks, boundaries and keypoints for selected boxes.

    A box is selected when ``scores`` is None or its score is strictly
    greater than ``min_score_thresh``; only the first ``max_boxes_to_draw``
    boxes are considered. Box outlines are not drawn by this function.

    NOTE(review): display strings (label / score / track id) and the
    per-box track-id map are collected, but nothing in this function renders
    them, so ``skip_scores``, ``skip_labels`` and ``skip_track_ids`` currently
    have no visible effect.

    Args:
        image: image array handed to the ``vis_util`` drawing helpers.
            Presumably modified in place by them -- confirm against vis_util.
        boxes: array with one row per box; ``boxes.shape[0]`` is the count
            and each row is converted to a tuple used as the box key.
        classes: per-box class ids, looked up in ``category_index``.
        scores: per-box scores, or None to select every box and colour it
            with ``groundtruth_box_visualization_color``.
        category_index: mapping from class id to a dict with a ``'name'``
            entry.
        instance_masks: optional per-box masks.
        instance_boundaries: optional per-box boundary masks (drawn in red,
            alpha 1.0).
        keypoints: optional per-box keypoints.
        track_ids: optional per-box track ids; when given (and not in
            agnostic mode) they determine the colour.
        use_normalized_coordinates: forwarded to the keypoint drawing helper.
        max_boxes_to_draw: maximum number of boxes to consider; a falsy
            value means all boxes.
        min_score_thresh: score a box must exceed to be selected.
        agnostic_mode: if True every selected box uses 'DarkOrange' and class
            names are omitted from the display string.
        line_thickness: only used to derive the keypoint radius (half of it).
        groundtruth_box_visualization_color: colour used when ``scores`` is
            None.
        skip_scores: omit the score from the display string.
        skip_labels: omit the class name from the display string.
        skip_track_ids: omit the track id from the display string.

    Returns:
        The same ``image`` object that was passed in.
    """
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)
    box_to_instance_masks_map = {}
    box_to_instance_boundaries_map = {}
    box_to_keypoints_map = collections.defaultdict(list)
    box_to_track_ids_map = {}
    if not max_boxes_to_draw:
        max_boxes_to_draw = boxes.shape[0]

    for i in range(min(max_boxes_to_draw, boxes.shape[0])):
        if scores is not None and not (scores[i] > min_score_thresh):
            continue
        # Boxes with identical coordinates share one key, so a later box
        # overwrites the mask/colour of an earlier identical one.
        box = tuple(boxes[i].tolist())
        if instance_masks is not None:
            box_to_instance_masks_map[box] = instance_masks[i]
        if instance_boundaries is not None:
            box_to_instance_boundaries_map[box] = instance_boundaries[i]
        if keypoints is not None:
            box_to_keypoints_map[box].extend(keypoints[i])
        if track_ids is not None:
            box_to_track_ids_map[box] = track_ids[i]

        if scores is None:
            box_to_color_map[box] = groundtruth_box_visualization_color
            continue

        display_str = ''
        if not skip_labels and not agnostic_mode:
            if classes[i] in category_index:
                class_name = category_index[classes[i]]['name']
            else:
                class_name = 'N/A'
            display_str = str(class_name)
        if not skip_scores:
            if not display_str:
                display_str = '{}%'.format(int(100*scores[i]))
            else:
                display_str = '{}: {}%'.format(
                    display_str, int(100*scores[i]))
        if not skip_track_ids and track_ids is not None:
            if not display_str:
                display_str = 'ID {}'.format(track_ids[i])
            else:
                display_str = '{}: ID {}'.format(display_str, track_ids[i])
        box_to_display_str_map[box].append(display_str)

        if agnostic_mode:
            box_to_color_map[box] = 'DarkOrange'
        elif track_ids is not None:
            prime_multiplier = vis_util._get_multiplier_for_color_randomness()
            box_to_color_map[box] = vis_util.STANDARD_COLORS[
                (prime_multiplier * track_ids[i])
                % len(vis_util.STANDARD_COLORS)]
        else:
            box_to_color_map[box] = vis_util.STANDARD_COLORS[
                classes[i] % len(vis_util.STANDARD_COLORS)]

    # Draw the overlays for every selected box onto the image.
    for box, color in box_to_color_map.items():
        if instance_masks is not None:
            vis_util.draw_mask_on_image_array(
                image,
                box_to_instance_masks_map[box],
                color=color
            )
        if instance_boundaries is not None:
            vis_util.draw_mask_on_image_array(
                image,
                box_to_instance_boundaries_map[box],
                color='red',
                alpha=1.0
            )
        if keypoints is not None:
            vis_util.draw_keypoints_on_image_array(
                image,
                box_to_keypoints_map[box],
                color=color,
                radius=line_thickness / 2,
                use_normalized_coordinates=use_normalized_coordinates)

    return image
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement