Advertisement
Guest User

Untitled

a guest
Jul 27th, 2016
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.95 KB | None | 0 0
  1.  
import os
import time

import numpy as np
import tensorflow as tf
  4.  
  5. FLAGS = tf.app.flags.FLAGS
  6.  
  7. tf.app.flags.DEFINE_integer('max_steps', 1000, "Number of rounds of testing to run.")
  8. tf.app.flags.DEFINE_integer('batch_size', 256, "Number of images to process in a batch.")
  9. tf.app.flags.DEFINE_integer('num_gpus', 4, """How many GPUs to use.""")
  10. tf.app.flags.DEFINE_boolean('log_device_placement', False, "Whether to log device placement.")
  11.  
  12.  
  13.  
  14. def train():
  15.     with tf.Graph().as_default(), tf.device('/cpu:0'):
  16.         images, labels = inputs(mode='train', batch_size=FLAGS.batch_size*FLAGS.num_gpus, no_threads=4)
  17.  
  18.         # Build an initialization operation
  19.         init = tf.initialize_all_variables()
  20.  
  21.         # Start running the operations
  22.         sess = tf.Session(config=tf.ConfigProto(                          
  23.                     allow_soft_placement=True,
  24.                     log_device_placement=FLAGS.log_device_placement))
  25.  
  26.         sess.run(init)
  27.            
  28.         # Start queue runners
  29.         tf.train.start_queue_runners(sess=sess)
  30.  
  31.         # Iterate through testing steps
  32.         for step in range(FLAGS.max_steps):
  33.             # Do one step of testing and time it
  34.             start_time = time.time()
  35.             _ = sess.run(images)
  36.             duration = time.time() - start_time
  37.             print(duration)
  38.  
  39.  
  40.  
  41.  
  42. def _read_and_decode(filename_queue):
  43.     reader = tf.TFRecordReader()
  44.     _, serialized_example = reader.read(filename_queue)
  45.     features = tf.parse_single_example(
  46.         serialized_example,
  47.  
  48.         features={
  49.             'label': tf.FixedLenFeature([], tf.int64),
  50.             'image_raw': tf.FixedLenFeature([], tf.string)
  51.         })
  52.  
  53.     image = tf.image.decode_jpeg(features['image_raw'], channels=3)
  54.     image = tf.random_crop(image, [IMAGE_SIZE_DISTORTED, IMAGE_SIZE_DISTORTED,
  55.                                    IMAGE_DEPTH])
  56.  
  57.     distorted_image = tf.cast(image, tf.float32)
  58.  
  59.     return distorted_image, features['label']
  60.  
  61.  
  62.  
  63.  
  64.  
  65.  
  66. def _input_pipeline(filenames, batch_size, num_epochs=None, shuffle=False,
  67.                     no_threads=1):
  68.     filename_queue = tf.train.string_input_producer(filenames,
  69.                                                     num_epochs=num_epochs,
  70.                                                     shuffle=False)
  71.  
  72.     example, label = _read_and_decode(filename_queue=filename_queue)
  73.  
  74.     min_after_dequeue = 1000
  75.     capacity = 1000 + 4 * batch_size
  76.  
  77.     if shuffle:
  78.         example_batch, label_batch = tf.train.shuffle_batch([example, label],
  79.                                                             batch_size=batch_size,
  80.                                                             capacity=capacity,
  81.                                                             min_after_dequeue=min_after_dequeue,
  82.                                                             num_threads=no_threads)
  83.     else:
  84.         example_batch, label_batch = tf.train.batch([example, label],
  85.                                                     batch_size=batch_size,
  86.                                                     capacity=capacity,
  87.                                                     num_threads=no_threads)
  88.  
  89.     return example_batch, label_batch
  90.  
  91. def inputs(mode="train", batch_size=64, no_threads=5):
  92.     if mode == "train":
  93.         files = os.listdir(TF_RECORDS_DIR_TRAIN)
  94.         filenames = [os.path.join(TF_RECORDS_DIR_TRAIN, file) for file in files]
  95.         shuffle = True
  96.     elif mode == "eval":
  97.         files = os.listdir(TF_RECORDS_DIR_EVAL)
  98.         filenames = [os.path.join(TF_RECORDS_DIR_EVAL, file) for file in files]
  99.         shuffle = False
  100.     else:
  101.         raise NotImplementedError("Please specify supported mode. Supported mode paramater values are: train and eval")
  102.        
  103.     return _input_pipeline(filenames=filenames, batch_size=batch_size, shuffle=shuffle, no_threads=no_threads)
  104.  
  105.        
  106. def main(argv=None):
  107.     train()
  108.  
  109.  
  110. if __name__ == '__main__':
  111.     tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement