Guest User

Untitled

a guest
Oct 17th, 2017
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.48 KB | None | 0 0
  1. import argparse
  2. import sys
  3. import time
  4.  
  5. import numpy as np
  6. import tensorflow as tf
  7. from tensorflow.python.tools import optimize_for_inference_lib
  8. from tensorflow.examples.tutorials.mnist import input_data
  9.  
  10.  
  11. FLAGS = None
  12.  
  13. NUM_ITERS = 20
  14. BATCH_SIZE = 4
  15. MNIST_X = 28
  16. MNIST_Y = 28
  17. MNIST_CHANNELS = 1
  18. MNIST_CLASSES = 10
  19.  
  20. # Up-scale the MNIST input to these sizes
  21. INPUT_X = 784
  22. INPUT_Y = 392
  23. INPUT_CHANNELS = 3
  24.  
  25. # Network size settings
  26. NUM_FEATURE_MAPS = 64
  27.  
  28.  
  29. def conv(conv_input, num_channels, name):
  30. with tf.variable_scope("conv_" + name):
  31. conv = tf.contrib.layers.convolution2d(conv_input, num_channels,
  32. activation_fn=tf.nn.relu,
  33. kernel_size=(3, 3), stride=1, padding="SAME")
  34. print(conv)
  35. return conv
  36.  
  37.  
  38. def model(network_input):
  39. conv1 = conv(network_input, NUM_FEATURE_MAPS, "conv1")
  40. conv2 = conv(conv1, NUM_FEATURE_MAPS, "conv2")
  41. conv3 = conv(conv2, NUM_FEATURE_MAPS, "conv3")
  42. result = conv(conv2, MNIST_CLASSES, "conv4")
  43. return result
  44.  
  45.  
  46. def main(_):
  47.  
  48. print("Loading MNIST data")
  49. mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  50.  
  51. print("Creating model")
  52. x = tf.placeholder(tf.float32, [None, INPUT_X, INPUT_Y, INPUT_CHANNELS])
  53. x = tf.identity(x, "input")
  54. y = model(x)
  55. y = tf.identity(y, "output")
  56. y_ = tf.placeholder(tf.float32, [None, INPUT_X, INPUT_Y, MNIST_CLASSES])
  57.  
  58. print("Defining loss")
  59. cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  60. train_step = tf.train.GradientDescentOptimizer(0.005).minimize(cross_entropy)
  61.  
  62. print("Starting session")
  63. print("Batch size: %d, num_steps: %d" % (BATCH_SIZE, NUM_ITERS))
  64. config = tf.ConfigProto()
  65. with tf.Session(config=config) as sess:
  66. tf.global_variables_initializer().run()
  67.  
  68. print("Starting training")
  69. for i in range(NUM_ITERS):
  70. batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
  71. batch_xs = np.reshape(batch_xs, newshape=[BATCH_SIZE, MNIST_X, MNIST_Y, MNIST_CHANNELS])
  72.  
  73. # Just to make the MNIST data a bit larger to reflect more realistic image sizes
  74. batch_xs = np.repeat(batch_xs, INPUT_X / MNIST_X, axis=1)
  75. batch_xs = np.repeat(batch_xs, INPUT_Y / MNIST_Y, axis=2)
  76. batch_xs = np.repeat(batch_xs, INPUT_CHANNELS / MNIST_CHANNELS, axis=3)
  77.  
  78. # Also makes the output larger: actually making it into a 2D label of the same size as the input
  79. batch_ys = np.repeat(batch_ys[:, np.newaxis, :], INPUT_X, axis=1)
  80. batch_ys = np.repeat(batch_ys[:, :, np.newaxis, :], INPUT_Y, axis=2)
  81.  
  82. sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  83. print("Done training")
  84.  
  85. print("Output graph to disk")
  86. graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
  87. graph = optimize_for_inference_lib.optimize_for_inference(graph, ["input"], ["output"],
  88. placeholder_type_enum=tf.float32.as_datatype_enum)
  89. tf.train.write_graph(graph, ".", "graph.pb", as_text=False)
  90.  
  91.  
  92. if __name__ == '__main__':
  93. parser = argparse.ArgumentParser()
  94. parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
  95. help='Directory for storing input data')
  96. FLAGS, unparsed = parser.parse_known_args()
  97. tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Add Comment
Please, Sign In to add comment