Advertisement
Guest User

Untitled

a guest
Mar 28th, 2017
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.99 KB | None | 0 0
  1. class NodeLookup(object):
  2. def __init__(self,
  3. label_lookup_path=None,
  4. uid_lookup_path=None):
  5. if not label_lookup_path:
  6. label_lookup_path = os.path.join(
  7. model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
  8. if not uid_lookup_path:
  9. uid_lookup_path = os.path.join(
  10. model_dir, 'imagenet_synset_to_human_label_map.txt')
  11. self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
  12.  
  13. def load(self, label_lookup_path, uid_lookup_path):
  14.  
  15. if not tf.gfile.Exists(uid_lookup_path):
  16. tf.logging.fatal('File does not exist %s', uid_lookup_path)
  17. if not tf.gfile.Exists(label_lookup_path):
  18. tf.logging.fatal('File does not exist %s', label_lookup_path)
  19.  
  20. # Loads mapping from string UID to human-readable string
  21. proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
  22. uid_to_human = {}
  23. p = re.compile(r'[n\d]*[ \S,]*')
  24. for line in proto_as_ascii_lines:
  25. parsed_items = p.findall(line)
  26. uid = parsed_items[0]
  27. human_string = parsed_items[2]
  28. uid_to_human[uid] = human_string
  29.  
  30. # Loads mapping from string UID to integer node ID.
  31. node_id_to_uid = {}
  32. proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
  33. for line in proto_as_ascii:
  34. if line.startswith(' target_class:'):
  35. target_class = int(line.split(': ')[1])
  36. if line.startswith(' target_class_string:'):
  37. target_class_string = line.split(': ')[1]
  38. node_id_to_uid[target_class] = target_class_string[1:-2]
  39.  
  40. # Loads the final mapping of integer node ID to human-readable string
  41. node_id_to_name = {}
  42. for key, val in node_id_to_uid.items():
  43. if val not in uid_to_human:
  44. tf.logging.fatal('Failed to locate: %s', val)
  45. name = uid_to_human[val]
  46. node_id_to_name[key] = name
  47.  
  48. return node_id_to_name
  49.  
  50. def id_to_string(self, node_id):
  51. if node_id not in self.node_lookup:
  52. return ''
  53. return self.node_lookup[node_id]
  54.  
  55.  
  56. def create_graph():
  57.  
  58. # Creates graph from saved graph_def.pb.
  59. with tf.gfile.FastGFile(os.path.join(
  60. model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
  61. graph_def = tf.GraphDef()
  62. graph_def.ParseFromString(f.read())
  63. _ = tf.import_graph_def(graph_def, name='')
  64.  
  65.  
  66. def maybe_download_and_extract():
  67. # Download and extract model tar file
  68. dest_directory = model_dir
  69. if not os.path.exists(dest_directory):
  70. os.makedirs(dest_directory)
  71. filename = DATA_URL.split('/')[-1]
  72. filepath = os.path.join(dest_directory, filename)
  73. if not os.path.exists(filepath):
  74. def _progress(count, block_size, total_size):
  75. sys.stdout.write('\r>> Downloading %s %.1f%%' % (
  76. filename, float(count * block_size) / float(total_size) * 100.0))
  77. sys.stdout.flush()
  78. filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
  79. print()
  80. statinfo = os.stat(filepath)
  81. print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  82. tarfile.open(filepath, 'r:gz').extractall(dest_directory)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement