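# inference2: frame-by-frame semantic segmentation with PSPNet (PSPNet50 for
# ADE20k, PSPNet101 for Cityscapes) on a video file or camera stream, using
# TensorFlow 1.x and OpenCV. Requires the companion model.py and tools.py
# modules imported below.
#
# Example invocation (a sketch; the filename inference2.py is assumed from the
# paste title, the flags come from get_arguments() below):
#   python inference2.py --dataset cityscapes --source video --checkpoints ./model/tmp/
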
from __future__ import print_function

import argparse
import os
import sys
import time
import tensorflow as tf
import numpy as np
from scipy import misc          # used only by the commented-out single-image path at the bottom
import cv2
from matplotlib import pyplot as plt  # imported but unused in this script

from model import PSPNet101, PSPNet50
from tools import *             # provides preprocess, decode_labels, load_img used below

ADE20k_param = {'crop_size': [473, 473],
                'num_classes': 150,
                'model': PSPNet50}
cityscapes_param = {'crop_size': [720, 720],
                    'num_classes': 19,
                    'model': PSPNet101}

SAVE_DIR = './output/'
SNAPSHOT_DIR = './model/tmp/'
VIDEO_PATH = './input/video_2.mp4'

def get_arguments():
    parser = argparse.ArgumentParser(description="Reproduced PSPNet")
    parser.add_argument("--img-path", type=str, default='',
                        help="Path to the RGB image file.")
    parser.add_argument("--checkpoints", type=str, default=SNAPSHOT_DIR,
                        help="Path to restore weights from.")
    parser.add_argument("--save-dir", type=str, default=SAVE_DIR,
                        help="Path to save output to.")
    parser.add_argument("--flipped-eval", action="store_true",
                        help="Whether to also evaluate on a horizontally flipped image.")
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        choices=['ade20k', 'cityscapes'])
    parser.add_argument("--source", type=str, default='video',
                        choices=['camera', 'video'])

    return parser.parse_args()

def save(saver, sess, logdir, step):
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)

    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)
    print('The checkpoint has been created.')

def load(saver, sess, ckpt_path):
    saver.restore(sess, ckpt_path)
    print("Restored model parameters from {}".format(ckpt_path))

def create_blank(height, width, rgb_color=(0, 0, 0)):
    """Create a new image (numpy array) filled with the given RGB color."""
    # Create a black blank image
    image = np.zeros((height, width, 3), np.uint8)

    # Since OpenCV uses BGR, convert the color first
    color = tuple(reversed(rgb_color))
    # Fill the image with the color
    image[:] = color

    return image

def twoImgShowVertically(img1, img2, border):
    # Despite the name, this lays the two images out side by side (left/right)
    # on a blank canvas, separated and surrounded by `border` pixels.
    h, w, channels = img1.shape

    img2 = cv2.resize(img2, (w, h))
    both = create_blank(h + border * 2, w * 2 + border * 3)

    both[border:border + h, border:border + w] = img1.copy()
    both[border:border + h, border * 2 + w:border * 2 + w * 2] = img2.copy()

    return both

def twoImgShowHorizontally(img1, img2, border):
    # Despite the name, this stacks the two images vertically (top/bottom)
    # on a blank canvas, separated and surrounded by `border` pixels.
    h, w, channels = img1.shape

    img2 = cv2.resize(img2, (w, h))
    both = create_blank(h * 2 + border * 3, w + border * 2)

    both[border:border + h, border:border + w] = img1.copy()
    both[border * 2 + h:border * 2 + h * 2, border:border + w] = img2.copy()

    return both

def main():
    args = get_arguments()

    ###########################################################
    # Load dataset-specific parameters
    if args.dataset == 'ade20k':
        param = ADE20k_param
    elif args.dataset == 'cityscapes':
        param = cityscapes_param

    crop_size = param['crop_size']
    num_classes = param['num_classes']
    PSPNet = param['model']

    # Preprocess images. Keep a separate handle on the raw input placeholder so
    # it can be fed with video frames later (rebinding `img` would lose it).
    input_img = tf.placeholder(tf.uint8, [None, None, 3])
    img_shape = tf.shape(input_img)
    h, w = (tf.maximum(crop_size[0], img_shape[0]), tf.maximum(crop_size[1], img_shape[1]))
    img = preprocess(input_img, h, w)

    # Create the network, plus a weight-sharing copy that sees the flipped input.
    net = PSPNet({'data': img}, is_training=False, num_classes=num_classes)
    with tf.variable_scope('', reuse=True):
        flipped_img = tf.image.flip_left_right(tf.squeeze(img))
        flipped_img = tf.expand_dims(flipped_img, axis=0)  # `axis` replaces the deprecated `dim` argument
        net2 = PSPNet({'data': flipped_img}, is_training=False, num_classes=num_classes)

    raw_output = net.layers['conv6']

    # Optionally combine the logits of the original and the un-flipped prediction
    if args.flipped_eval:
        flipped_output = tf.image.flip_left_right(tf.squeeze(net2.layers['conv6']))
        flipped_output = tf.expand_dims(flipped_output, axis=0)
        raw_output = tf.add_n([raw_output, flipped_output])

    # Predictions: upsample, crop back to the input size, take the per-pixel argmax
    raw_output_up = tf.image.resize_bilinear(raw_output, size=[h, w], align_corners=True)
    raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0, img_shape[0], img_shape[1])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = decode_labels(raw_output_up, img_shape, num_classes)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()

    # Load checkpoint
    ckpt = tf.train.get_checkpoint_state(args.checkpoints)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found.')

    # Capture the video source
    if args.source == 'camera':
        cap = cv2.VideoCapture(0)
    elif args.source == 'video':
        cap = cv2.VideoCapture(VIDEO_PATH)

    # Check whether the capture device/file opened successfully
    if not cap.isOpened():
        print("Error opening video stream or file")

    # Read until the video is completed
    while cap.isOpened():
        # Capture frame-by-frame
        ret, frame = cap.read()
        frameId = cap.get(cv2.CAP_PROP_POS_FRAMES)  # current frame number (unused)

        if ret:
            # Resize only after a successful read; `frame` is None once the video ends
            frame = cv2.resize(frame, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
            #frame_np_expanded = np.expand_dims(frame, axis=0)

            ##########################################################################
            # Predict: feed the raw frame into the input placeholder
            #########################################################################
            preds = sess.run(pred, feed_dict={input_img: frame})
            ##########################################################################

            # Display the original frame and the decoded label map side by side
            result = twoImgShowVertically(frame, preds[0], 50)
            cv2.imshow('Video 2', result)
            # Press 'q' on the keyboard to exit
            if cv2.waitKey(33) & 0xFF == ord('q'):
                break

        else:
            break

    # When everything is done, release the video capture object
    cap.release()

    sess.close()

    # Close all OpenCV windows
    cv2.destroyAllWindows()
    ###########################################################

    # The original single-image inference path, kept commented out for reference:
    '''
    # preprocess images
    img, filename = load_img(args.img_path)
    img_shape = tf.shape(img)
    h, w = (tf.maximum(crop_size[0], img_shape[0]), tf.maximum(crop_size[1], img_shape[1]))
    img = preprocess(img, h, w)

    # Create network.
    net = PSPNet({'data': img}, is_training=False, num_classes=num_classes)
    with tf.variable_scope('', reuse=True):
        flipped_img = tf.image.flip_left_right(tf.squeeze(img))
        flipped_img = tf.expand_dims(flipped_img, dim=0)
        net2 = PSPNet({'data': flipped_img}, is_training=False, num_classes=num_classes)

    raw_output = net.layers['conv6']

    # Do flipped eval or not
    if args.flipped_eval:
        flipped_output = tf.image.flip_left_right(tf.squeeze(net2.layers['conv6']))
        flipped_output = tf.expand_dims(flipped_output, dim=0)
        raw_output = tf.add_n([raw_output, flipped_output])

    # Predictions.
    raw_output_up = tf.image.resize_bilinear(raw_output, size=[h, w], align_corners=True)
    raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0, img_shape[0], img_shape[1])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = decode_labels(raw_output_up, img_shape, num_classes)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()

    ckpt = tf.train.get_checkpoint_state(args.checkpoints)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found.')

    preds = sess.run(pred)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    misc.imsave(args.save_dir + filename, preds[0])
    '''

if __name__ == '__main__':
    main()