cifar10_train.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A binary to train CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.

Speed: With batch_size 128.

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

FLAGS = tf.compat.v1.flags.FLAGS

tf.compat.v1.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.compat.v1.flags.DEFINE_integer('max_steps', 100,
                            """Number of batches to run.""")
tf.compat.v1.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.compat.v1.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")

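# Example invocation (a sketch only; it assumes this file is saved as
# cifar10_train.py and that the companion cifar10.py module is importable):
#
#   python cifar10_train.py --train_dir=/tmp/cifar10_train \
#       --max_steps=1000 --log_frequency=10
#
# All flags are optional and fall back to the defaults defined above; note
# that max_steps defaults to 100 here, far fewer than the 100K steps the
# docstring's accuracy figures assume.
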
#config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=2, inter_op_parallelism_threads=2, allow_soft_placement=True, device_count={'CPU': 2})
#session = tf.compat.v1.Session(config=config)
#tf.compat.v1.keras.backend.set_session(session)


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.compat.v1.train.SessionRunHook):
      """Logs loss and runtime."""

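      # SessionRunHook lifecycle: begin() runs once before the session is
      # created; before_run() is called before each session.run() and its
      # SessionRunArgs asks the runner to also fetch `loss`; after_run()
      # then receives that fetched value in run_values.results.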
      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.compat.v1.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

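    # MonitoredTrainingSession saves and restores checkpoints in train_dir,
    # writes summaries, and dispatches the hooks: StopAtStepHook ends
    # training after max_steps and NanTensorHook aborts if the loss becomes
    # NaN.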
    with tf.compat.v1.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.compat.v1.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.compat.v1.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.compat.v1.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
  if tf.io.gfile.exists(FLAGS.train_dir):
    tf.compat.v1.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.io.gfile.makedirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.compat.v1.app.run()