Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
145
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.35 KB | None | 0 0
  1. import os
  2. import argparse
  3. import numpy as np
  4. import tensorflow as tf
  5.  
  6. from scipy.misc import imread, imsave, imresize
  7. from matplotlib import pyplot as plt
  8.  
  9. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  10.  
  11. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  12.  
  13. # input image path
  14. parser = argparse.ArgumentParser()
  15.  
  16. parser.add_argument('--im_path', type=str, default='./demo/45765448.jpg',
  17.                     help='input image paths.')
  18.  
  19. # color map
  20. floorplan_map = {
  21.     0: [255,255,255], # background
  22.     1: [192,192,224], # closet
  23.     2: [192,255,255], # batchroom/washroom
  24.     3: [224,255,192], # livingroom/kitchen/dining room
  25.     4: [255,224,128], # bedroom
  26.     5: [255,160, 96], # hall
  27.     6: [255,224,224], # balcony
  28.     7: [255,255,255], # not used
  29.     8: [255,255,255], # not used
  30.     9: [255, 60,128], # door & window
  31.     10:[  0,  0,  0]  # wall
  32. }
  33.  
  34. def ind2rgb(ind_im, color_map=floorplan_map):
  35.     rgb_im = np.zeros((ind_im.shape[0], ind_im.shape[1], 3))
  36.  
  37.     for i, rgb in color_map.iteritems():
  38.         rgb_im[(ind_im==i)] = rgb
  39.  
  40.     return rgb_im
  41.  
  42. def main(args):
  43.     # load input
  44.     im = imageio.imread(args.im_path, mode='RGB')
  45.     im = im.astype(np.float32)
  46.     im = skimage.transform.resize(im, (512,512,3)) / 255.
  47.  
  48.     # create tensorflow session
  49.     with tf.Session() as sess:
  50.        
  51.         # initialize
  52.         sess.run(tf.group(tf.global_variables_initializer(),
  53.                     tf.local_variables_initializer()))
  54.  
  55.         # restore pretrained model
  56.         saver = tf.train.import_meta_graph('./pretrained/pretrained_r3d.meta')
  57.         saver.restore(sess, './pretrained/pretrained_r3d')
  58.  
  59.         # get default graph
  60.         graph = tf.get_default_graph()
  61.  
  62.         # restore inputs & outpus tensor
  63.         x = graph.get_tensor_by_name('inputs:0')
  64.         room_type_logit = graph.get_tensor_by_name('Cast:0')
  65.         room_boundary_logit = graph.get_tensor_by_name('Cast_1:0')
  66.  
  67.         # infer results
  68.         [room_type, room_boundary] = sess.run([room_type_logit, room_boundary_logit],\
  69.                                         feed_dict={x:im.reshape(1,512,512,3)})
  70.         room_type, room_boundary = np.squeeze(room_type), np.squeeze(room_boundary)
  71.  
  72.         # merge results
  73.         floorplan = room_type.copy()
  74.         floorplan[room_boundary==1] = 9
  75.         floorplan[room_boundary==2] = 10
  76.         floorplan_rgb = ind2rgb(floorplan)
  77.  
  78.         # plot results
  79.         plt.subplot(121)
  80.         plt.imshow(im)
  81.         plt.subplot(122)
  82.         plt.imshow(floorplan_rgb/255.)
  83.         plt.show()
  84.  
  85. if __name__ == '__main__':
  86.     FLAGS, unparsed = parser.parse_known_args()
  87.     main(FLAGS)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement