Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class NodeLookup(object):
    """Converts integer node IDs to human-readable ImageNet labels.

    Two files are combined:

    * ``label_lookup_path`` -- a text proto whose entries pair a
      ``target_class`` integer with a quoted ``target_class_string`` UID.
    * ``uid_lookup_path`` -- a tab-separated file mapping each UID to a
      human-readable label.

    When a path is not supplied, a default file name inside the
    module-level ``model_dir`` is used.
    """

    def __init__(self,
                 label_lookup_path=None,
                 uid_lookup_path=None):
        if not label_lookup_path:
            label_lookup_path = os.path.join(
                model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
        if not uid_lookup_path:
            uid_lookup_path = os.path.join(
                model_dir, 'imagenet_synset_to_human_label_map.txt')
        # dict: integer node ID -> human-readable label string.
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    def load(self, label_lookup_path, uid_lookup_path):
        """Build the node-ID -> human-readable-label mapping.

        Args:
            label_lookup_path: path to the label-map text proto.
            uid_lookup_path: path to the tab-separated UID -> label file.

        Returns:
            dict mapping each integer node ID to its label string.

        Raises:
            KeyError: if a UID referenced by the label map has no label.
            ValueError: if a line of either file cannot be parsed.
            Whatever ``tf.gfile.GFile`` raises when a file cannot be opened.
        """
        # NOTE(review): tf.logging.fatal appears to only log (it does not
        # abort in TF 1.x); a missing file then fails when it is opened below.
        if not tf.gfile.Exists(uid_lookup_path):
            tf.logging.fatal('File does not exist %s', uid_lookup_path)
        if not tf.gfile.Exists(label_lookup_path):
            tf.logging.fatal('File does not exist %s', label_lookup_path)

        # Read each file inside a context manager so its handle is closed.
        with tf.gfile.GFile(uid_lookup_path) as f:
            uid_to_human = self._parse_uid_to_human(f.readlines())
        with tf.gfile.GFile(label_lookup_path) as f:
            node_id_to_uid = self._parse_node_id_to_uid(f.readlines())

        # Join the two maps: node ID -> UID -> human-readable label.
        node_id_to_name = {}
        for node_id, uid in node_id_to_uid.items():
            if uid not in uid_to_human:
                tf.logging.fatal('Failed to locate: %s', uid)
                raise KeyError(uid)
            node_id_to_name[node_id] = uid_to_human[uid]
        return node_id_to_name

    @staticmethod
    def _parse_uid_to_human(lines):
        """Parse tab-separated ``uid``/``label`` lines into a dict."""
        uid_to_human = {}
        for line in lines:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            uid, sep, human_string = line.partition('\t')
            if not sep:
                raise ValueError('Malformed UID line: %r' % line)
            uid_to_human[uid] = human_string.strip()
        return uid_to_human

    @staticmethod
    def _parse_node_id_to_uid(lines):
        """Collect ``target_class`` -> ``target_class_string`` pairs."""
        node_id_to_uid = {}
        target_class = None
        for line in lines:
            # Compare on the stripped line so the proto's indentation width
            # does not matter.
            field, sep, value = line.strip().partition(':')
            if not sep:
                continue  # e.g. "entry {" or "}"
            if field == 'target_class':
                target_class = int(value)
            elif field == 'target_class_string':
                if target_class is None:
                    raise ValueError(
                        'target_class_string %s precedes any target_class'
                        % value.strip())
                # The value is a double-quoted UID; drop whitespace and quotes.
                node_id_to_uid[target_class] = value.strip().strip('"')
        return node_id_to_uid

    def id_to_string(self, node_id):
        """Return the label for ``node_id``, or ``''`` if it is unknown."""
        return self.node_lookup.get(node_id, '')
def create_graph(graph_path=None):
    """Import a serialized GraphDef into the current default TensorFlow graph.

    Args:
        graph_path: optional path to a binary ``GraphDef`` file. Defaults to
            ``classify_image_graph_def.pb`` inside the module-level
            ``model_dir``, which matches the original no-argument behavior.
    """
    if graph_path is None:
        graph_path = os.path.join(model_dir, 'classify_image_graph_def.pb')
    # GFile replaces the deprecated FastGFile; 'rb' because the file is a
    # binary protobuf.
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name prefix, so tensors keep
    # their original names.
    tf.import_graph_def(graph_def, name='')
def maybe_download_and_extract():
    """Download the model tarball into ``model_dir`` if absent, then extract it.

    The archive's local name is the last path component of ``DATA_URL``.
    The download is skipped when the archive already exists; extraction
    into ``model_dir`` runs on every call.
    """
    dest_directory = model_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # urlretrieve reports total_size <= 0 when the size is unknown;
            # fall back to a byte count instead of dividing by it.
            downloaded = count * block_size
            if total_size > 0:
                percent = min(
                    100.0, float(downloaded) / float(total_size) * 100.0)
                sys.stdout.write(
                    '\r>> Downloading %s %.1f%%' % (filename, percent))
            else:
                sys.stdout.write(
                    '\r>> Downloading %s %d bytes' % (filename, downloaded))
            sys.stdout.flush()

        # Download under a temporary name and rename only on success, so an
        # interrupted download cannot leave a truncated archive that a later
        # call would mistake for a complete one.
        partial_path = filepath + '.part'
        urllib.request.urlretrieve(DATA_URL, partial_path, _progress)
        os.rename(partial_path, filepath)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Context manager closes the archive after extraction.
    with tarfile.open(filepath, 'r:gz') as tar:
        # The archive comes from the network. When the running Python
        # supports extraction filters, the 'data' filter rejects members
        # that would be written outside dest_directory.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest_directory, filter='data')
        else:
            tar.extractall(dest_directory)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement