Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # -*- coding: utf-8 -*-
- import time
- import tensorflow as tf
- slim = tf.contrib.slim
- from models.inception_resnet_v2.inception_resnet_v2 import inception_resnet_v2_arg_scope, inception_resnet_v2
- import os
- import json
- import lmdb
- from tensorflow.contrib.data import Dataset
# Command-line flags (parsed by tf.app.run before main is called).
FLAGS = tf.app.flags.FLAGS
# Passed to tf.train.latest_checkpoint in main, so this is a directory,
# not a single checkpoint file.
tf.app.flags.DEFINE_string('ckpt', 'logs/20171114_111459/',
                           'Checkpoint directory; the latest checkpoint in it is restored')
tf.app.flags.DEFINE_string('lmdb_file', 'lmdbs/features/inception_tuned',
                           'Directory where to write lmdb with features')
tf.app.flags.DEFINE_string('meta', 'meta',
                           'Meta dir containing json/retrieval_dresses.json')
tf.app.flags.DEFINE_string('images_dir', 'data/images_299',
                           'Directory holding the <photo>.jpg images')
tf.app.flags.DEFINE_integer('batch_size', 512, 'Batch size')
def _parse_function(filename):
    """Read one JPEG from disk and turn it into a network-ready tensor.

    Args:
      filename: scalar string tensor holding the image path.

    Returns:
      A ``(filename, image)`` pair. ``filename`` is passed through unchanged.
      ``image`` is a float tensor of shape (299, 299, 3) whose values are
      mapped from [0, 255] to [-1, 1].
    """
    raw_bytes = tf.read_file(filename)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    # Center-crop or zero-pad to 299x299.
    pixels = tf.image.resize_image_with_crop_or_pad(pixels, 299, 299)
    # [0, 255] -> [0, 1] -> [-0.5, 0.5] -> [-1, 1]
    scaled = (tf.to_float(pixels) / 255.0 - 0.5) * 2.0
    return filename, scaled
def _lmdb_key(path):
    """Return the LMDB key (photo id as UTF-8 bytes) for an image path.

    sess.run returns string tensor elements as bytes on Python 3 (str on
    Python 2), so decode first rather than relying on str() of a bytes object,
    which would produce its "b'...'" repr.
    """
    if isinstance(path, bytes):
        path = path.decode('utf-8')
    # Filenames are built as '<photo>.jpg', so stripping only the final
    # extension recovers the photo id even if it contains dots.
    photo_id = os.path.splitext(os.path.basename(path))[0]
    return photo_id.encode('utf-8')


def main(argv=None):
    """Extract 'PreLogitsFlatten' features for every listed photo into LMDB.

    Reads the 'photo' field of each entry in <meta>/json/retrieval_dresses.json.
    Feeds <images_dir>/<photo>.jpg through inception_resnet_v2, restored from
    the latest checkpoint in FLAGS.ckpt. Stores each feature vector's raw bytes
    (fv.tobytes()) under the photo-id key in the LMDB at FLAGS.lmdb_file.

    Args:
      argv: unused; supplied by tf.app.run.

    Raises:
      ValueError: if FLAGS.ckpt contains no checkpoint.
    """
    with open(os.path.join(FLAGS.meta, 'json', 'retrieval_dresses.json')) as f:
        elems = json.load(f)
    dresses = [x['photo'] for x in elems]
    filenames = [os.path.join(FLAGS.images_dir, '{}.jpg'.format(photo))
                 for photo in dresses]
    if not filenames:
        # tf.constant([]) would be inferred as float32, which tf.read_file
        # rejects while the pipeline is being built.
        print('No images listed; nothing to do.')
        return

    # Input pipeline: path -> preprocessed image, batched.
    dataset = Dataset.from_tensor_slices(tf.constant(filenames))
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(FLAGS.batch_size)
    iterator = dataset.make_initializable_iterator()
    path_batch, image_batch = iterator.get_next()

    # Fail early with a clear message instead of inside saver.restore(None).
    checkpoint = tf.train.latest_checkpoint(FLAGS.ckpt)
    if checkpoint is None:
        raise ValueError('No checkpoint found in {}'.format(FLAGS.ckpt))

    # Load the model
    with slim.arg_scope(inception_resnet_v2_arg_scope()):
        _, end_points = inception_resnet_v2(image_batch, is_training=False)
    features = end_points['PreLogitsFlatten']
    saver = tf.train.Saver()

    if not os.path.exists(FLAGS.lmdb_file):
        os.makedirs(FLAGS.lmdb_file)

    total = len(filenames)
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), iterator.initializer])
        saver.restore(sess, checkpoint)

        # map_size is the maximum size the LMDB may grow to (64 GiB here).
        env = lmdb.open(FLAGS.lmdb_file, map_size=64 * 1024 * 1024 * 1024)
        try:
            processed = 0
            while True:
                t0 = time.time()
                try:
                    paths, fvs = sess.run([path_batch, features])
                except tf.errors.OutOfRangeError:
                    break  # dataset exhausted
                infer_time = time.time() - t0
                # One write transaction per batch.
                with env.begin(write=True) as txn:
                    for path, fv in zip(paths, fvs):
                        txn.put(_lmdb_key(path), fv.tobytes())
                duration = time.time() - t0
                # Count actual rows so the last, partial batch is not overcounted.
                processed += len(paths)
                print('{} / {} (time: {:.4f} ({:.4f}))'.format(
                    processed, total, duration, infer_time))
        finally:
            env.close()
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then calls main(argv).
    tf.app.run(main=main)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement