Advertisement
Guest User

Untitled

a guest
Jun 17th, 2018
296
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.70 KB | None | 0 0
  1. import _init_paths
  2. import tensorflow as tf
  3. from fast_rcnn.config import cfg
  4. from fast_rcnn.test import im_detect
  5. from fast_rcnn.nms_wrapper import nms
  6. from utils.timer import Timer
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import os, sys, cv2
  10. import argparse
  11. from networks.factory import get_network
  12. from os import listdir
  13. from os.path import isfile, join
  14.  
  15.  
  16. CLASSES = ('__background__',  # always index 0
  17.            'plate', 'piece of yarn', 'hand', 'car toy',
  18.            'utensils', 'textured block', 'dump truck', 'board book', 'multiple pop-up',
  19.            'cow', 'toy telephone', 'baby doll', 'jack-in-the-box',
  20.            'head', 'child', 'cup',
  21.            'medium-sized ball', 'small ball', '8 letter block', 'music box')
  22.  
  23.  
  24. #CLASSES = ('__background__','person','bike','motorbike','car','bus')
  25.  
  26. def vis_detections(im, class_name, dets,ax, thresh=0.5):
  27.     """Draw detected bounding boxes."""
  28.     inds = np.where(dets[:, -1] >= thresh)[0]
  29.     if len(inds) == 0:
  30.         return
  31.  
  32.     for i in inds:
  33.         bbox = dets[i, :4]
  34.         score = dets[i, -1]
  35.  
  36.         ax.add_patch(
  37.             plt.Rectangle((bbox[0], bbox[1]),
  38.                           bbox[2] - bbox[0],
  39.                           bbox[3] - bbox[1], fill=False,
  40.                           edgecolor='red', linewidth=3.5)
  41.             )
  42.         ax.text(bbox[0], bbox[1] - 2,
  43.                 '{:s} {:.3f}'.format(class_name, score),
  44.                 bbox=dict(facecolor='blue', alpha=0.5),
  45.                 fontsize=14, color='white')
  46.  
  47.     ax.set_title(('{} detections with '
  48.                   'p({} | box) >= {:.1f}').format(class_name, class_name,
  49.                                                   thresh),
  50.                   fontsize=14)
  51.     plt.axis('off')
  52.     plt.tight_layout()
  53.     plt.draw()
  54.  
  55.  
  56. def demo(sess, net, image_name):
  57.     """Detect object classes in an image using pre-computed object proposals."""
  58.  
  59.     # Load the demo image
  60.     im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
  61.     #im_file = os.path.join('/home/corgi/Lab/label/pos_frame/ACCV/training/000001/',image_name)
  62.     im = cv2.imread(im_file)
  63.  
  64.     # Detect all object classes and regress object bounds
  65.     timer = Timer()
  66.     timer.tic()
  67.     scores, boxes = im_detect(sess, net, im)
  68.     timer.toc()
  69.     print ('Detection took {:.3f}s for '
  70.            '{:d} object proposals').format(timer.total_time, boxes.shape[0])
  71.  
  72.     # Visualize detections for each class
  73.     im = im[:, :, (2, 1, 0)]
  74.     fig, ax = plt.subplots(figsize=(12, 12))
  75.     ax.imshow(im, aspect='equal')
  76.  
  77.     CONF_THRESH = 0.8
  78.     NMS_THRESH = 0.3
  79.     for cls_ind, cls in enumerate(CLASSES[1:]):
  80.         cls_ind += 1 # because we skipped background
  81.         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
  82.         cls_scores = scores[:, cls_ind]
  83.         dets = np.hstack((cls_boxes,
  84.                           cls_scores[:, np.newaxis])).astype(np.float32)
  85.         keep = nms(dets, NMS_THRESH)
  86.         dets = dets[keep, :]
  87.         vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
  88.  
  89. def parse_args():
  90.     """Parse input arguments."""
  91.     parser = argparse.ArgumentParser(description='Faster R-CNN demo')
  92.     parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
  93.                         default=0, type=int)
  94.     parser.add_argument('--cpu', dest='cpu_mode',
  95.                         help='Use CPU mode (overrides --gpu)',
  96.                         action='store_true')
  97.     parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
  98.                         default='VGGnet_test')
  99.     parser.add_argument('--model', dest='model', help='Model path',
  100.                         default=' ')
  101.  
  102.     args = parser.parse_args()
  103.  
  104.     return args
  105. if __name__ == '__main__':
  106.     cfg.TEST.HAS_RPN = True  # Use RPN for proposals
  107.  
  108.     args = parse_args()
  109.  
  110.     if args.model == ' ':
  111.         raise IOError(('Error: Model not found.\n'))
  112.        
  113.     # init session
  114.     sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  115.     # load network
  116.     net = get_network(args.demo_net)
  117.     # load model
  118.     saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
  119.     saver.restore(sess, args.model)
  120.    
  121.     #sess.run(tf.initialize_all_variables())
  122.  
  123.     print '\n\nLoaded network {:s}'.format(args.model)
  124.  
  125.     # Warmup on a dummy image
  126.     im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
  127.     for i in xrange(2):
  128.         _, _= im_detect(sess, net, im)
  129.  
  130.    
  131.     onlyfiles = [f for f in listdir('./data/demo') if isfile(join('./data/demo', f))]
  132.     for im_name in onlyfiles:
  133.         print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  134.         print 'Demo for data/demo/{}'.format(im_name)
  135.         demo(sess, net, im_name)
  136.  
  137.     plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement