Advertisement
Guest User

Untitled

a guest
Apr 15th, 2019
122
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.17 KB | None | 0 0
  1. # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. #     http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15.  
  16. """A binary to train CIFAR-10 using a single GPU.
  17.  
  18. Accuracy:
  19. cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
  20. data) as judged by cifar10_eval.py.
  21.  
  22. Speed: With batch_size 128.
  23.  
  24. System        | Step Time (sec/batch)  |     Accuracy
  25. ------------------------------------------------------------------
  26. 1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
  27. 1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)
  28.  
  29. Usage:
  30. Please see the tutorial and website for how to download the CIFAR-10
  31. data set, compile the program and train the model.
  32.  
  33. http://tensorflow.org/tutorials/deep_cnn/
  34. """
  35. from __future__ import absolute_import
  36. from __future__ import division
  37. from __future__ import print_function
  38.  
  39. from datetime import datetime
  40. import os.path
  41. import time
  42.  
  43. import pycurl
  44. import re
  45. from StringIO import StringIO
  46. import tinys3
  47. import os
  48. import sys
  49. import random, string
  50. import numpy as np
  51. from six.moves import xrange  # pylint: disable=redefined-builtin
  52. import tensorflow as tf
  53. import pprint
  54. import cifar10
  55. import pprint
  56.  
  57. from tensorflow.python.tools import inspect_checkpoint as chkp
  58.  
  59.  
  60. interrupt_check_url = "http://169.254.169.254/latest/meta-data/spot/termination-time"
  61.  
  62. FLAGS = tf.app.flags.FLAGS
  63.  
  64. tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
  65.                            """Directory where to write event logs """
  66.                            """and checkpoint.""")
  67. tf.app.flags.DEFINE_integer('max_steps', sys.maxint,
  68.                             """Number of batches to run.""")
  69. tf.app.flags.DEFINE_boolean('log_device_placement', False,
  70.                             """Whether to log device placement.""")
  71. tf.app.flags.DEFINE_string('checkpoint_dir', None,
  72.                            """Checkpoint file path to start training""")
  73.  
  74.  
  75. def train():
  76.     """Train CIFAR-10 for a number of steps."""
  77.     with tf.Graph().as_default():
  78.         global_step = tf.Variable(0, trainable=False)
  79.  
  80.         images, labels = cifar10.distorted_inputs() # Get images and labels for CIFAR-10.
  81.  
  82.         logits = cifar10.inference(images) # Build a Graph that computes the logits predictions from the
  83.         # inference model.
  84.  
  85.         loss = cifar10.loss(logits, labels) # Calculate loss.
  86.  
  87.         train_op = cifar10.train(loss, global_step) # Build a Graph that trains the model with one batch of examples and updates the model parameters.
  88.         # print (tf.global_variables())
  89.         saver = tf.train.Saver(tf.global_variables()) # Create a saver.
  90.  
  91.         summary_op = tf.summary.merge_all() # Build the summary operation based on the TF collection of Summaries.
  92.  
  93.         init = tf.global_variables_initializer() # Build an initialization operation to run below.
  94.  
  95.         sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
  96.         sess.run(init) # Start running operations on the Graph.
  97.  
  98.         if FLAGS.checkpoint_dir is not None:
  99.             ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  100.             print("checkpoint path is %s" % ckpt.model_checkpoint_path)
  101.             tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
  102.  
  103.         # Start the queue runners.
  104.         print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir)
  105.         tf.train.start_queue_runners(sess=sess)
  106.         summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
  107.  
  108.         cur_step = sess.run(global_step);
  109.         print("current step is %s" % cur_step)
  110.  
  111.         interrupt_check_duration = 0.0
  112.         elapsed_time = time.time()
  113.         flag = 0
  114.         for step in xrange(cur_step, FLAGS.max_steps):
  115.             _, loss_value = sess.run([train_op, loss])
  116.             assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
  117.  
  118.             # if step % 10 == 0:
  119.             #     print (step)
  120.             #     num_examples_per_step = FLAGS.batch_size
  121.             #     examples_per_sec = num_examples_per_step / duration
  122.             #     sec_per_batch = float(duration)
  123.  
  124.             #     format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
  125.             #                   'sec/batch)')
  126.             #     print(format_str % (datetime.now(), step, loss_value,
  127.             #                         examples_per_sec, sec_per_batch))
  128.  
  129.             if step % 10 == 0:
  130.                 # print (step)
  131.                 # summary_str = sess.run(summary_op)
  132.                 # # print (summary_str)
  133.                 # # break
  134.                 # summary_writer.add_summary(summary_str, step) # Adds a Summary protocol buffer to the event file.
  135.                 # # print (dir(summary_writer))
  136.                 # # print (summary_writer.get_logdir)
  137.                 # # print (sess.__class__)
  138.  
  139.                 # reader = tf.train.NewCheckpointReader("/tmp/ckpt/model.ckpt-70000")
  140.                 # print (reader)
  141.                 # variables = reader.get_variable_to_shape_map()
  142.                 # print (variables)
  143.                 # # for ele in variables:
  144.                 #     # print (ele)
  145.                 # print ("\n")
  146.  
  147.                 saver.restore(sess, "/tmp/ckpt/model.ckpt-60000")
  148.                
  149.          
  150.                 tvars = tf.trainable_variables()
  151.                 tvars_vals = sess.run(tvars)
  152.                 for var, val in zip(tvars, tvars_vals):
  153.                     # print(var.name, val)  # Prints the name of the variable alongside its value.
  154.                     print(var.name)  # Prints the name of the variable alongside its value.
  155.                     print(val)
  156.  
  157.  
  158.                 break
  159.  
  160.  
  161.             # Save the model checkpoint periodically.
  162.             # if step % 100 == 0 or (step + 1) == FLAGS.max_steps:
  163.  
  164.  
  165.                 # with tf.variable_scope('conv1') as scope:
  166.                 #     # print (scope.get_variable())
  167.                 # #     print (scope)
  168.                 #     tf.get_variable_scope().reuse_variables()
  169.                 #     w = tf.get_variable('weights')
  170.                 #     b = tf.get_variable
  171.                 #     print (w)
  172.                 #     print (b)
  173.  
  174.  
  175.                 # saver = tf.train.Saver([conv1/weights])
  176.                 # print (saver)
  177.  
  178.  
  179.                 # vars = tf.train.list_variables("/tmp/mj/cifar10_train")
  180.                 # pprint.pprint(vars)
  181.  
  182.  
  183.                 # reader = tf.train.load_checkpoint("/tmp/mj/cifar10_train")
  184.                 # variable_map = reader.get_variable_to_shape_map()
  185.                 # names = (variable_map.keys())
  186.                 # result = []
  187.                 # for name in names:
  188.                 #     result.append((name, variable_map[name]))
  189.                 # pprint.pprint(result)
  190.  
  191.  
  192.                 # pprint.pprint(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=''))
  193.  
  194.                 # reader = tf.train.load_checkpoint("/tmp/ckpt/")
  195.                 # variable_map = reader.get_variable_to_shape_map()
  196.                 # names = (variable_map.keys())
  197.                 # result = []
  198.                 # for name in names:
  199.                 #     result.append((name, variable_map[name]))
  200.                 # pprint.pprint(result)
  201.  
  202.  
  203.                 # print (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1'))
  204.  
  205.  
  206.                 # saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1'))
  207.                 # print (saver._CheckpointFilename())
  208.  
  209.  
  210.                 # ckpt_state = tf.train.get_checkpoint_state("/tmp/mj")
  211.                 # print(type(ckpt_state))
  212.                 # print(ckpt_state.model_checkpoint_path)
  213.                 # print(ckpt_state.all_model_checkpoint_paths)
  214.  
  215.  
  216.                 # recent_ckpt_job_path = tf.train.latest_checkpoint("/tmp/mj/cifar10_train")
  217.                 # print(recent_ckpt_job_path)
  218.  
  219.  
  220.                 # recent_ckpt_job_path = tf.train.latest_checkpoint("/tmp/mj/cifar10_train")
  221.                 # chkp.print_tensors_in_checkpoint_file(recent_ckpt_job_path, all_tensors=True, tensor_name='')
  222.  
  223.  
  224.                 # recent_ckpt_job_path = tf.train.latest_checkpoint("/tmp/mj/cifar10_train")
  225.                 # print (recent_ckpt_job_path)
  226.  
  227.  
  228.                 # chkp.print_tensors_in_checkpoint_file("/tmp/mj/cifar10_train/model.ckpt-0", all_tensors=True, tensor_name='')
  229.  
  230.  
  231.                 # var_lists = chkp.print_tensors_in_checkpoint_file("/tmp/mj/cifar10_train/model.ckpt-400", all_tensors=True, tensor_name='')
  232.                 # print ((var_lists))
  233.  
  234.  
  235.                 # with tf.Session() as sess:
  236.                 # saver.restore(sess, "/tmp/mj/cifar10_train/model.ckpt-400")
  237.                 # print ("Model restored.")
  238.  
  239.  
  240.  
  241.                 # saver.restore(sess, "/tmp/mj/cifar10_train/conv1/model.ckpt-100")
  242.                 # tvars = tf.tf.global_variables_initializer()
  243.                 # tvars_vals = sess.run(tvars)
  244.                 # for var, val in zip(tvars, tvars_vals):
  245.                 #     # print(var.name, val)  # Prints the name of the variable alongside its value.
  246.                 #     print(var.name)  # Prints the name of the variable alongside its value.
  247.  
  248.  
  249.  
  250.                 # reader = tf.train.NewCheckpointReader("/tmp/mj/cifar10_train/model.ckpt-0")
  251.                 # print (reader)
  252.                 # variables = reader.get_variable_to_shape_map()
  253.                 # print (variables)
  254.                 # # for ele in variables:
  255.                 #     # print (ele)
  256.                 # print ("\n")
  257.  
  258.  
  259.                 # print (sess.run(w.shape))
  260.                 # sess.run()
  261.                 # print (w)
  262.                 # print (saver)
  263.                 # sess.run(w)
  264.  
  265.  
  266.                 # print ([n.name for n in tf.get_default_graph().as_graph_def().node])
  267.  
  268.  
  269.                 # tvars = tf.trainable_variables()
  270.                 # tvars_vals = sess.run(tvars)
  271.                 # for var, val in zip(tvars, tvars_vals):
  272.                 #     print(var.name, val)  # Prints the name of the variable alongside its value.
  273.  
  274.  
  275.                 # saver = tf.train.Saver(tf.get_variable('conv1')) # Create a saver.
  276.                 # print (saver)
  277.  
  278.  
  279.                 # saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1'))
  280.  
  281.  
  282.                 # checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
  283.                 # saver.save(sess, checkpoint_path, global_step=step)
  284.  
  285.  
  286.                 # if step == 100:
  287.                     # break
  288.  
  289.  
  290. def main(argv=None):  # pylint: disable=unused-argument
  291.     print("train directory is %s" % (FLAGS.train_dir))
  292.     cifar10.maybe_download_and_extract()
  293.     train()
  294.  
  295.  
  296. if __name__ == '__main__':
  297.     tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement