from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import sys
import time

from six.moves import xrange  # needed under Python 3 for the xrange calls below
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

NUM_CLASSES = 10
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

def inference(images, hidden1_units, hidden2_units):
  """Build the MNIST model up to where it may be used for inference.

  Args:
    images: Images placeholder, from inputs().
    hidden1_units: Size of the first hidden layer.
    hidden2_units: Size of the second hidden layer.

  Returns:
    softmax_linear: Output tensor with the computed logits.
  """
  # Hidden 1
  with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  # Hidden 2
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  # Linear
  with tf.name_scope('softmax_linear'):
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases
  return logits

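# For reference, a minimal sketch of the same network using tf.layers.dense
# (assuming TF 1.x, where tf.layers is available). Not used below: the
# explicit-variable version above is kept because its variable names match
# the checkpoint restored at the bottom.
#
# def inference_dense(images, hidden1_units, hidden2_units):
#   hidden1 = tf.layers.dense(images, hidden1_units, activation=tf.nn.relu, name='hidden1')
#   hidden2 = tf.layers.dense(hidden1, hidden2_units, activation=tf.nn.relu, name='hidden2')
#   return tf.layers.dense(hidden2, NUM_CLASSES, name='softmax_linear')
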
def loss(logits, labels):
  """Return the mean cross-entropy loss between logits and integer labels."""
  labels = tf.to_int64(labels)
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

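# Roughly equivalent with the lower-level op (a sketch; it ignores the loss
# collections that tf.losses maintains):
#
# xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# loss = tf.reduce_mean(xent)
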
def training(loss, learning_rate):
  # Add a scalar summary for the snapshot loss.
  tf.summary.scalar('loss', loss)
  # Create the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # Create a variable to track the global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # Use the optimizer to apply the gradients that minimize the loss
  # (and also increment the global step counter) as a single training step.
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op

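# Other optimizers drop in the same way, e.g. (a sketch):
#
# optimizer = tf.train.AdamOptimizer(learning_rate)
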
def evaluation(logits, labels, name):
  """Return an op counting how many logits rank the true label first."""
  correct = tf.nn.in_top_k(logits, labels, 1)
  return tf.reduce_sum(tf.cast(correct, tf.int32), name=name)

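# tf.nn.in_top_k yields one boolean per example, so an accuracy op (fraction
# correct rather than a count) could be formed as a sketch:
#
# accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
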
input_data_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/input_data')
data_sets = read_data_sets(input_data_dir, False)
log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'tensorflow/mnist/logs/fully_connected_feed')
batch_size = 100
learning_rate = 0.01
hidden1 = 128
hidden2 = 32
max_steps = 2000
checkpoint_file = os.path.join(os.getenv('HOME'), 'Tmp/tf_converter/model.ckpt')

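# Note: tf.train.Saver (V2 format) writes checkpoint_file + '.index',
# '.meta' and '.data-*' files; the restore section at the bottom expects
# these to already exist at checkpoint_file.
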
### Utils

def placeholder_inputs(batch_size):
  # Explicit names let these tensors be recovered with get_tensor_by_name
  # after the graph is re-imported below.
  images_placeholder = tf.placeholder(
      tf.float32, shape=(batch_size, IMAGE_PIXELS), name='images_placeholder')
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size), name='labels_placeholder')
  return images_placeholder, labels_placeholder

def fill_feed_dict(data_set, images_pl, labels_pl):
  images_feed, labels_feed = data_set.next_batch(batch_size, False)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // batch_size
  num_examples = steps_per_epoch * batch_size  # remainder examples are dropped

  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / num_examples
  print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))

## Start training
# (This block is intentionally disabled; it was run once to produce the
# checkpoint restored below.)
"""
with tf.Graph().as_default():
  # summary = tf.summary.merge_all()
  # summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
  sess = tf.Session()

  images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
  logits = inference(images_placeholder, 128, 32)
  loss_op = loss(logits, labels_placeholder)  # avoid shadowing the loss() function
  train_op = training(loss_op, learning_rate)
  eval_correct = evaluation(logits, labels_placeholder, 'eval_correct')
  saver = tf.train.Saver()

  init = tf.global_variables_initializer()
  sess.run(init)

  for step in xrange(max_steps):
    start_time = time.time()

    feed_dict = fill_feed_dict(
        data_sets.train, images_placeholder, labels_placeholder)
    _, loss_value = sess.run([train_op, loss_op], feed_dict=feed_dict)

    duration = time.time() - start_time

    if step % 100 == 0:
      print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
      # Update the events file.
      # summary_str = sess.run(summary, feed_dict=feed_dict)
      # summary_writer.add_summary(summary_str, step)
      # summary_writer.flush()

    if (step + 1) % 1000 == 0 or (step + 1) == max_steps:
      # saver.save(sess, checkpoint_file, global_step=step)
      print('Test Data Eval:')
      do_eval(sess, eval_correct, images_placeholder,
              labels_placeholder, data_sets.test)

  saver.save(sess, checkpoint_file)
"""

# Load and save
meta_file = checkpoint_file + '.meta'

with tf.Graph().as_default():
  sess = tf.Session()
  new_saver = tf.train.import_meta_graph(meta_file)
  # g = new_saver.export_meta_graph()
  # print(g)  # e.g. python mnist.py > graph.txt
  # g.SerializeToString()
  # g.ListFields()
  new_saver.restore(sess, checkpoint_file)
  graph = tf.get_default_graph()
  # graph.get_operations()

  images_placeholder = graph.get_tensor_by_name('images_placeholder:0')
  labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
  eval_correct = graph.get_tensor_by_name('eval_correct:0')

  print('Test Data Eval after restoring model:')
  do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.test)

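# If the tensor names were not known in advance, the imported graph could be
# inspected first, e.g. (a sketch):
#
# for op in graph.get_operations():
#   print(op.name)
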
#====
# from tensorflow.python.tools import inspect_checkpoint as chkp
# chkp.print_tensors_in_checkpoint_file("model.ckpt", tensor_name='', all_tensors=True)