SHARE
TWEET

Untitled

a guest Jan 21st, 2019 73 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4.  
  5. import math
  6. import os
  7. import sys
  8. import time
  9.  
  10. import tensorflow as tf
  11. from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
  12.  
  13. NUM_CLASSES = 10
  14. IMAGE_SIZE = 28
  15. IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
  16.  
  17. def inference(images, hidden1_units, hidden2_units):
  18.   """Build the MNIST model up to where it may be used for inference.
  19.   Args:
  20.     images: Images placeholder, from inputs().
  21.     hidden1_units: Size of the first hidden layer.
  22.     hidden2_units: Size of the second hidden layer.
  23.   Returns:
  24.     softmax_linear: Output tensor with the computed logits.
  25.   """
  26.   # Hidden 1
  27.   with tf.name_scope('hidden1'):
  28.     weights = tf.Variable(
  29.         tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
  30.                             stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
  31.         name='weights')
  32.     biases = tf.Variable(tf.zeros([hidden1_units]),
  33.                          name='biases')
  34.     hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  35.   # Hidden 2
  36.   with tf.name_scope('hidden2'):
  37.     weights = tf.Variable(
  38.         tf.truncated_normal([hidden1_units, hidden2_units],
  39.                             stddev=1.0 / math.sqrt(float(hidden1_units))),
  40.         name='weights')
  41.     biases = tf.Variable(tf.zeros([hidden2_units]),
  42.                          name='biases')
  43.     hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  44.   # Linear
  45.   with tf.name_scope('softmax_linear'):
  46.     weights = tf.Variable(
  47.         tf.truncated_normal([hidden2_units, NUM_CLASSES],
  48.                             stddev=1.0 / math.sqrt(float(hidden2_units))),
  49.         name='weights')
  50.     biases = tf.Variable(tf.zeros([NUM_CLASSES]),
  51.                          name='biases')
  52.     logits = tf.matmul(hidden2, weights) + biases
  53.   return logits
  54.  
  55.  
  56. def loss(logits, labels):
  57.   labels = tf.to_int64(labels)
  58.   return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  59.  
  60.  
  61. def training(loss, learning_rate):
  62.   # Add a scalar summary for the snapshot loss.
  63.   tf.summary.scalar('loss', loss)
  64.   # Create the gradient descent optimizer with the given learning rate.
  65.   optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  66.   # Create a variable to track the global step.
  67.   global_step = tf.Variable(0, name='global_step', trainable=False)
  68.   # Use the optimizer to apply the gradients that minimize the loss
  69.   # (and also increment the global step counter) as a single training step.
  70.   train_op = optimizer.minimize(loss, global_step=global_step)
  71.   return train_op
  72.  
  73.  
  74. def evaluation(logits, labels, name):
  75.   correct = tf.nn.in_top_k(logits, labels, 1)
  76.   return tf.reduce_sum(tf.cast(correct, tf.int32), name=name)
  77.  
  78.  
  79. input_data_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/input_data')
  80. data_sets = read_data_sets(input_data_dir, False)
  81. log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/logs/fully_connected_feed'),
  82. batch_size = 100
  83. learning_rate = 0.01
  84. hidden1 = 128
  85. hidden2 = 32
  86. max_steps = 2000
  87. checkpoint_file = os.path.join(os.getenv('HOME'), 'Tmp/tf_converter/model.ckpt')
  88.  
  89. ### Utils
  90.  
  91. def placeholder_inputs(batch_size):
  92.   images_placeholder = tf.placeholder(
  93.     tf.float32, shape=(batch_size, IMAGE_PIXELS), name='images_placeholder')
  94.   labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size), name='labels_placeholder')
  95.   return images_placeholder, labels_placeholder
  96.  
  97. def fill_feed_dict(data_set, images_pl, labels_pl):
  98.   images_feed, labels_feed = data_set.next_batch(batch_size, False)
  99.   feed_dict = {
  100.       images_pl: images_feed,
  101.       labels_pl: labels_feed,
  102.   }
  103.   return feed_dict
  104.  
  105. def do_eval(sess,
  106.             eval_correct,
  107.             images_placeholder,
  108.             labels_placeholder,
  109.             data_set):
  110.   # And run one epoch of eval.
  111.   true_count = 0  # Counts the number of correct predictions.
  112.   steps_per_epoch = data_set.num_examples // batch_size
  113.   num_examples = steps_per_epoch * batch_size
  114.  
  115.   for step in xrange(steps_per_epoch):
  116.     feed_dict = fill_feed_dict(data_set,
  117.                                images_placeholder,
  118.                                labels_placeholder)
  119.     true_count += sess.run(eval_correct, feed_dict=feed_dict)
  120.   precision = float(true_count) / num_examples
  121.   print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
  122.         (num_examples, true_count, precision))
  123.  
  124.  
  125. ## Start training
  126. """
  127. with tf.Graph().as_default():
  128.   # summary = tf.summary.merge_all()
  129.   # summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
  130.   sess = tf.Session()
  131.  
  132.   images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
  133.   logits = inference(images_placeholder, 128, 32)
  134.   loss = loss(logits, labels_placeholder)
  135.   train_op = training(loss, learning_rate)
  136.   eval_correct = evaluation(logits, labels_placeholder, 'eval_correct')
  137.   saver = tf.train.Saver()
  138.  
  139.   init = tf.global_variables_initializer()
  140.   sess.run(init)
  141.  
  142.   for step in xrange(max_steps):
  143.     start_time = time.time()
  144.  
  145.     feed_dict = fill_feed_dict(
  146.       data_sets.train, images_placeholder, labels_placeholder)
  147.     _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
  148.  
  149.     duration = time.time() - start_time
  150.  
  151.     if step % 100 == 0:
  152.       print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
  153.       # Update the events file.
  154.       # summary_str = sess.run(summary, feed_dict=feed_dict)
  155.       # summary_writer.add_summary(summary_str, step)
  156.       # summary_writer.flush()
  157.    
  158.     if (step + 1) % 1000 == 0 or (step + 1) == max_steps:
  159.       # saver.save(sess, checkpoint_file, global_step=step)
  160.       print('Test Data Eval:')
  161.       do_eval(sess, eval_correct, images_placeholder,
  162.         labels_placeholder, data_sets.test)
  163.  
  164.   saver.save(sess, checkpoint_file)
  165. """
  166.  
  167. # Load and save
  168. meta_file = checkpoint_file + '.meta'
  169.  
  170. with tf.Graph().as_default():
  171.   sess = tf.Session()
  172.   new_saver = tf.train.import_meta_graph(meta_file)
  173.   # g = new_saver.export_meta_graph()
  174.   # print(g) # python mnist.py > fuck.txt
  175.   # g.SerializeToString()
  176.   # g.ListFields()
  177.   new_saver.restore(sess, checkpoint_file)
  178.   graph = tf.get_default_graph()
  179.   # graph.get_operations()
  180.  
  181.   images_placeholder = graph.get_tensor_by_name('images_placeholder:0')
  182.   labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
  183.   eval_correct = graph.get_tensor_by_name('eval_correct:0')
  184.  
  185.   print('Test Data Eval after restoring model:')
  186.   do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.test)
  187.  
  188.  
  189. #====
  190. #from tensorflow.python.tools import inspect_checkpoint as chkp
  191. #chkp.print_tensors_in_checkpoint_file("model.ckpt", tensor_name='', all_tensors=True)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top