Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """A binary to train CIFAR-10 using a single GPU.
- Accuracy:
- cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
- data) as judged by cifar10_eval.py.
- Speed: With batch_size 128.
- System | Step Time (sec/batch) | Accuracy
- ------------------------------------------------------------------
- 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
- 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
- Usage:
- Please see the tutorial and website for how to download the CIFAR-10
- data set, compile the program and train the model.
- http://tensorflow.org/tutorials/deep_cnn/
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from datetime import datetime
- import os.path
- import time
- import pycurl
- import re
- from StringIO import StringIO
- import tinys3
- import os
- import sys
- import random, string
- import numpy as np
- from six.moves import xrange # pylint: disable=redefined-builtin
- import tensorflow as tf
- import pprint
- import cifar10
- import pprint
- from tensorflow.python.tools import inspect_checkpoint as chkp
- interrupt_check_url = "http://169.254.169.254/latest/meta-data/spot/termination-time"
- FLAGS = tf.app.flags.FLAGS
- tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
- """Directory where to write event logs """
- """and checkpoint.""")
- tf.app.flags.DEFINE_integer('max_steps', sys.maxint,
- """Number of batches to run.""")
- tf.app.flags.DEFINE_boolean('log_device_placement', False,
- """Whether to log device placement.""")
- tf.app.flags.DEFINE_string('checkpoint_dir', None,
- """Checkpoint file path to start training""")
- def train():
- """Train CIFAR-10 for a number of steps."""
- with tf.Graph().as_default():
- global_step = tf.Variable(0, trainable=False)
- images, labels = cifar10.distorted_inputs() # Get images and labels for CIFAR-10.
- logits = cifar10.inference(images) # Build a Graph that computes the logits predictions from the
- # inference model.
- loss = cifar10.loss(logits, labels) # Calculate loss.
- train_op = cifar10.train(loss, global_step) # Build a Graph that trains the model with one batch of examples and updates the model parameters.
- # print (tf.global_variables())
- saver = tf.train.Saver(tf.global_variables()) # Create a saver.
- summary_op = tf.summary.merge_all() # Build the summary operation based on the TF collection of Summaries.
- init = tf.global_variables_initializer() # Build an initialization operation to run below.
- sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
- sess.run(init) # Start running operations on the Graph.
- if FLAGS.checkpoint_dir is not None:
- ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
- print("checkpoint path is %s" % ckpt.model_checkpoint_path)
- tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
- # Start the queue runners.
- print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir)
- tf.train.start_queue_runners(sess=sess)
- summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
- cur_step = sess.run(global_step);
- print("current step is %s" % cur_step)
- interrupt_check_duration = 0.0
- elapsed_time = time.time()
- flag = 0
- for step in xrange(cur_step, FLAGS.max_steps):
- _, loss_value = sess.run([train_op, loss])
- assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
- # if step % 10 == 0:
- # print (step)
- # num_examples_per_step = FLAGS.batch_size
- # examples_per_sec = num_examples_per_step / duration
- # sec_per_batch = float(duration)
- # format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
- # 'sec/batch)')
- # print(format_str % (datetime.now(), step, loss_value,
- # examples_per_sec, sec_per_batch))
- if step % 10 == 0:
- # print (step)
- # summary_str = sess.run(summary_op)
- # # print (summary_str)
- # # break
- # summary_writer.add_summary(summary_str, step) # Adds a Summary protocol buffer to the event file.
- # # print (dir(summary_writer))
- # # print (summary_writer.get_logdir)
- # # print (sess.__class__)
- # reader = tf.train.NewCheckpointReader("/tmp/ckpt/model.ckpt-70000")
- # print (reader)
- # variables = reader.get_variable_to_shape_map()
- # print (variables)
- # # for ele in variables:
- # # print (ele)
- # print ("\n")
- saver.restore(sess, "/tmp/ckpt/model.ckpt-60000")
- tvars = tf.trainable_variables()
- tvars_vals = sess.run(tvars)
- for var, val in zip(tvars, tvars_vals):
- # print(var.name, val) # Prints the name of the variable alongside its value.
- print(var.name) # Prints the name of the variable alongside its value.
- print(val)
- break
- # Save the model checkpoint periodically.
- # if step % 100 == 0 or (step + 1) == FLAGS.max_steps:
- # with tf.variable_scope('conv1') as scope:
- # # print (scope.get_variable())
- # # print (scope)
- # tf.get_variable_scope().reuse_variables()
- # w = tf.get_variable('weights')
- # b = tf.get_variable
- # print (w)
- # print (b)
- # saver = tf.train.Saver([conv1/weights])
- # print (saver)
- # vars = tf.train.list_variables("/tmp/mj/cifar10_train")
- # pprint.pprint(vars)
- # reader = tf.train.load_checkpoint("/tmp/mj/cifar10_train")
- # variable_map = reader.get_variable_to_shape_map()
- # names = (variable_map.keys())
- # result = []
- # for name in names:
- # result.append((name, variable_map[name]))
- # pprint.pprint(result)
- # pprint.pprint(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=''))
- # reader = tf.train.load_checkpoint("/tmp/ckpt/")
- # variable_map = reader.get_variable_to_shape_map()
- # names = (variable_map.keys())
- # result = []
- # for name in names:
- # result.append((name, variable_map[name]))
- # pprint.pprint(result)
- # print (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1'))
- # saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1'))
- # print (saver._CheckpointFilename())
- # ckpt_state = tf.train.get_checkpoint_state("/tmp/mj")
- # print(type(ckpt_state))
- # print(ckpt_state.model_checkpoint_path)
- # print(ckpt_state.all_model_checkpoint_paths)
- # recent_ckpt_job_path = tf.train.latest_checkpoint("/tmp/mj/cifar10_train")
- # print(recent_ckpt_job_path)
- # recent_ckpt_job_path = tf.train.latest_checkpoint("/tmp/mj/cifar10_train")
- # chkp.print_tensors_in_checkpoint_file(recent_ckpt_job_path, all_tensors=True, tensor_name='')
- # recent_ckpt_job_path = tf.train.latest_checkpoint("/tmp/mj/cifar10_train")
- # print (recent_ckpt_job_path)
- # chkp.print_tensors_in_checkpoint_file("/tmp/mj/cifar10_train/model.ckpt-0", all_tensors=True, tensor_name='')
- # var_lists = chkp.print_tensors_in_checkpoint_file("/tmp/mj/cifar10_train/model.ckpt-400", all_tensors=True, tensor_name='')
- # print ((var_lists))
- # with tf.Session() as sess:
- # saver.restore(sess, "/tmp/mj/cifar10_train/model.ckpt-400")
- # print ("Model restored.")
- # saver.restore(sess, "/tmp/mj/cifar10_train/conv1/model.ckpt-100")
- # tvars = tf.tf.global_variables_initializer()
- # tvars_vals = sess.run(tvars)
- # for var, val in zip(tvars, tvars_vals):
- # # print(var.name, val) # Prints the name of the variable alongside its value.
- # print(var.name) # Prints the name of the variable alongside its value.
- # reader = tf.train.NewCheckpointReader("/tmp/mj/cifar10_train/model.ckpt-0")
- # print (reader)
- # variables = reader.get_variable_to_shape_map()
- # print (variables)
- # # for ele in variables:
- # # print (ele)
- # print ("\n")
- # print (sess.run(w.shape))
- # sess.run()
- # print (w)
- # print (saver)
- # sess.run(w)
- # print ([n.name for n in tf.get_default_graph().as_graph_def().node])
- # tvars = tf.trainable_variables()
- # tvars_vals = sess.run(tvars)
- # for var, val in zip(tvars, tvars_vals):
- # print(var.name, val) # Prints the name of the variable alongside its value.
- # saver = tf.train.Saver(tf.get_variable('conv1')) # Create a saver.
- # print (saver)
- # saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1'))
- # checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
- # saver.save(sess, checkpoint_path, global_step=step)
- # if step == 100:
- # break
- def main(argv=None): # pylint: disable=unused-argument
- print("train directory is %s" % (FLAGS.train_dir))
- cifar10.maybe_download_and_extract()
- train()
- if __name__ == '__main__':
- tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement