Advertisement
lamiastella

PointNet feature extraction

Jun 7th, 2021
1,358
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.89 KB | None | 0 0
  1. import argparse
  2. import math
  3. import h5py
  4. import numpy as np
  5. #import tensorflow as tf
  6. import tensorflow.compat.v1 as tf
  7. tf.disable_v2_behavior()
  8. import socket
  9. import importlib
  10. import os
  11. import sys
  12. #import cPickle as pickle
  13. import pickle
  14. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  15. sys.path.append(BASE_DIR)
  16. sys.path.append(os.path.join(BASE_DIR, 'models'))
  17. sys.path.append(os.path.join(BASE_DIR, 'utils'))
  18. import tf_util
  19.  
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
  22. parser.add_argument('--model', default='pointnet_hico', help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
  23. parser.add_argument('--num_point', type=int, default=1228, help='Point Number [256/512/1024/2048] [default: 1024]')
  24. parser.add_argument('--model_path', default='log/model.ckpt', help='model checkpoint file path [default: log/model.ckpt]')
  25. parser.add_argument('--input_list', default='./', help='Path list of your point cloud files [default: ./pc_list.txt]')
  26. FLAGS = parser.parse_args()
  27.  
  28.  
  29. NUM_POINT = FLAGS.num_point
  30. GPU_INDEX = FLAGS.gpu
  31. MODEL_PATH = FLAGS.model_path
  32. BATCH_SIZE = 1
  33. MODEL = importlib.import_module(FLAGS.model) # import network module
  34. MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model+'.py')
  35.  
  36. MAX_NUM_POINT = 1228
  37. NUM_CLASSES = 600
  38.  
  39. HOSTNAME = socket.gethostname()
  40. print('HOSTNAME: ', HOSTNAME)
  41.  
  42. def evaluate():
  43.     with tf.device('/gpu:'+str(GPU_INDEX)):
  44.         pointclouds_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
  45.         is_training_pl = tf.placeholder(tf.bool, shape=())
  46.  
  47.  
  48.         # simple model
  49.         feat = MODEL.get_model(pointclouds_pl, is_training_pl)
  50.        
  51.         # Add ops to save and restore all the variables.
  52.         saver = tf.train.Saver()
  53.        
  54.     # Create a session
  55.     config = tf.ConfigProto()
  56.     config.gpu_options.allow_growth = True
  57.     config.allow_soft_placement = True
  58.     config.log_device_placement = True
  59.     sess = tf.Session(config=config)
  60.  
  61.     # Restore variables from disk.
  62.     saver.restore(sess, MODEL_PATH)
  63.  
  64.     ops = {'pointclouds_pl': pointclouds_pl,
  65.            'is_training_pl': is_training_pl,
  66.            'feat': feat}
  67.  
  68.     eval_one_epoch(sess, ops)
  69.  
  70.    
  71. def eval_one_epoch(sess, ops):
  72.     is_training = False
  73.     input_list = None
  74.     with open(FLAGS.input_list, 'r') as f:
  75.         input_list = f.readlines()
  76.    
  77.     for fn in range(len(input_list)):
  78.         current_data = pickle.load(open(fn, 'rb'))
  79.         current_data = current_data[None, :NUM_POINT, :]
  80.        
  81.            
  82.         feed_dict = {ops['pointclouds_pl']: current_data,
  83.                      ops['is_training_pl']: is_training}
  84.         feat = sess.run([ops['feat']], feed_dict=feed_dict)
  85.         print('filename: ', fn)
  86.         pickle.dump(feat, open(fn[:-4] + '_feature.pkl', 'wb'))
  87.  
  88. with tf.Graph().as_default():
  89.     evaluate()
  90.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement