Advertisement
Guest User

Untitled

a guest
Nov 20th, 2017
67
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.02 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. import time
  3.  
  4. import tensorflow as tf
  5.  
  6. slim = tf.contrib.slim
  7. from models.inception_resnet_v2.inception_resnet_v2 import inception_resnet_v2_arg_scope, inception_resnet_v2
  8. import os
  9. import json
  10. import lmdb
  11. from tensorflow.contrib.data import Dataset
  12.  
  13. FLAGS = tf.app.flags.FLAGS
  14.  
  15. tf.app.flags.DEFINE_string('ckpt', 'logs/20171114_111459/', 'Checkpoint file')
  16. tf.app.flags.DEFINE_string('lmdb_file', 'lmdbs/features/inception_tuned',
  17. 'Directory where to write lmdb with features')
  18. tf.app.flags.DEFINE_string('meta', 'meta', 'Meta dir')
  19. tf.app.flags.DEFINE_string('images_dir', 'data/images_299', 'Images dir')
  20. tf.app.flags.DEFINE_integer('batch_size', 512, 'Batch size')
  21.  
  22.  
  23. def _parse_function(filename):
  24. image_string = tf.read_file(filename)
  25. image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  26. resized = tf.image.resize_image_with_crop_or_pad(image_decoded, 299, 299)
  27. resized = tf.to_float(resized)
  28. inputs_preprocess = resized / 255.0
  29. inputs_preprocess = inputs_preprocess - 0.5
  30. inputs_preprocess = inputs_preprocess * 2.0
  31. return filename, inputs_preprocess
  32.  
  33.  
  34. def main(argv=None):
  35. with open(os.path.join(FLAGS.meta, 'json', 'retrieval_dresses.json')) as f:
  36. elems = json.load(f)
  37. dresses = [x['photo'] for x in elems]
  38.  
  39. filenames = [os.path.join(FLAGS.images_dir, '{}.jpg'.format(image)) for image in dresses]
  40.  
  41. dataset = Dataset.from_tensor_slices(tf.constant(filenames))
  42. dataset = dataset.map(_parse_function)
  43. dataset = dataset.batch(FLAGS.batch_size)
  44.  
  45. iterator = dataset.make_initializable_iterator()
  46. name, img = iterator.get_next()
  47.  
  48. # Load the model
  49. with tf.Session() as sess:
  50. arg_scope = inception_resnet_v2_arg_scope()
  51.  
  52. with slim.arg_scope(arg_scope):
  53. _, end_points = inception_resnet_v2(img, is_training=False)
  54. saver = tf.train.Saver()
  55. init = tf.global_variables_initializer()
  56. sess.run([init, iterator.initializer])
  57. saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt))
  58.  
  59. if not os.path.exists(FLAGS.lmdb_file):
  60. os.makedirs(FLAGS.lmdb_file)
  61. env = lmdb.open(FLAGS.lmdb_file, map_size=64 * 1024 * 1024 * 1024)
  62.  
  63. iter = 0
  64. while True:
  65. try:
  66. t0 = time.time()
  67. names, fvs = sess.run([name, end_points['PreLogitsFlatten']])
  68. infer_time = time.time() - t0
  69.  
  70. with env.begin(write=True) as txn:
  71. for name_, fv in zip(names, fvs):
  72. name_ = str(name_).split('/')[-1].split('.')[0]
  73. txn.put(name_.encode('ascii'), fv.tobytes())
  74. duration = time.time() - t0
  75. iter += FLAGS.batch_size
  76. print('{} / {} (time: {:.4f} ({:.4f}))'.format(iter, len(filenames), duration, infer_time))
  77. except tf.errors.OutOfRangeError:
  78. break
  79.  
  80.  
  81. if __name__ == '__main__':
  82. tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement