Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """A binary to train CIFAR-10 using a single GPU.
- Accuracy:
- cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
- data) as judged by cifar10_eval.py.
- Speed: With batch_size 128.
- System | Step Time (sec/batch) | Accuracy
- ------------------------------------------------------------------
- 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
- 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
- Usage:
- Please see the tutorial and website for how to download the CIFAR-10
- data set, compile the program and train the model.
- http://tensorflow.org/tutorials/deep_cnn/
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from datetime import datetime
- import time
- import tensorflow as tf
- import cifar10
FLAGS = tf.compat.v1.flags.FLAGS

# Command-line flags controlling the training run.
tf.compat.v1.flags.DEFINE_string(
    'train_dir', '/tmp/cifar10_train',
    'Directory where to write event logs and checkpoint.')
tf.compat.v1.flags.DEFINE_integer(
    'max_steps', 100,
    'Number of batches to run.')
tf.compat.v1.flags.DEFINE_boolean(
    'log_device_placement', False,
    'Whether to log device placement.')
tf.compat.v1.flags.DEFINE_integer(
    'log_frequency', 10,
    'How often to log results to the console.')

# Disabled: a custom session configuration (two intra-op/inter-op threads,
# soft placement, two CPU devices) installed as the Keras backend session.
# config = tf.compat.v1.ConfigProto(
#     intra_op_parallelism_threads=2, inter_op_parallelism_threads=2,
#     allow_soft_placement=True, device_count={'CPU': 2})
# session = tf.compat.v1.Session(config=config)
# tf.compat.v1.keras.backend.set_session(session)
def train():
    """Build the CIFAR-10 training graph and run it in a monitored session.

    The graph is assembled from the `cifar10` module (inputs, inference,
    loss and train op) inside a fresh `tf.Graph`. Training then loops until
    the session reports that it should stop; a `StopAtStepHook` is
    configured with `FLAGS.max_steps` as its last step. Checkpoints and
    event logs go to `FLAGS.train_dir`.
    """
    with tf.Graph().as_default():
        step_tensor = tf.compat.v1.train.get_or_create_global_step()

        # Pin the input pipeline to CPU:0; otherwise some of its operations
        # can end up on the GPU and slow training down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Logits from the inference model, then the loss on those logits.
        logits = cifar10.inference(images)
        loss = cifar10.loss(logits, labels)

        # Op that trains the model on one batch and updates its parameters.
        train_op = cifar10.train(loss, step_tensor)

        class _LoggerHook(tf.compat.v1.train.SessionRunHook):
            """Prints loss and throughput every `FLAGS.log_frequency` runs."""

            def begin(self):
                # Local run counter; the first before_run() bumps it to 0.
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Request the loss value so after_run() can report it.
                return tf.compat.v1.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                frequency = FLAGS.log_frequency
                if self._step % frequency != 0:
                    return

                # Time elapsed since the previous report (or since begin()).
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                current_loss = run_values.results
                throughput = frequency * FLAGS.batch_size / elapsed
                seconds_per_batch = float(elapsed / frequency)

                template = ('%s: step %d, loss = %.2f '
                            '(%.1f examples/sec; %.3f sec/batch)')
                print(template % (datetime.now(), self._step, current_loss,
                                  throughput, seconds_per_batch))

        session_hooks = [
            tf.compat.v1.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.compat.v1.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        session_config = tf.compat.v1.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)

        with tf.compat.v1.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
def main(argv=None):  # pylint: disable=unused-argument
    """Reset the training directory and start training.

    Any existing `FLAGS.train_dir` is removed recursively and recreated
    empty before `train()` is called, so each run starts without earlier
    checkpoints or event logs.

    Args:
      argv: Command-line arguments passed in by the app runner; unused.
    """
    # Use the tf.io.gfile API throughout (exists / rmtree / makedirs) rather
    # than mixing it with the legacy tf.compat.v1.gfile.DeleteRecursively.
    if tf.io.gfile.exists(FLAGS.train_dir):
        tf.io.gfile.rmtree(FLAGS.train_dir)
    tf.io.gfile.makedirs(FLAGS.train_dir)
    train()
if __name__ == '__main__':
    # Hand control to the TensorFlow app runner, naming the entry point
    # explicitly instead of relying on its default lookup of `main`.
    tf.compat.v1.app.run(main=main)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement