# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A deep MNIST classifier using convolutional layers.

This example was adapted from
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist_deep.py.

Each worker reads the full MNIST dataset and asynchronously trains a CNN with
dropout using the Adam optimizer, updating the model parameters on shared
parameter servers.

The current training accuracy is printed out every 100 steps.
"""
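
# The cluster layout, job name, and task index are read from environment
# variables in main() below. A minimal, purely illustrative local launch could
# look like the following (hostnames, ports, and the script name are
# placeholders, not part of the original example):
#
#   export CLUSTER_SPEC='{"ps": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}'
#   JOB_NAME=ps     TASK_INDEX=0 python mnist_distributed.py &
#   JOB_NAME=worker TASK_INDEX=0 python mnist_distributed.py &
#   JOB_NAME=worker TASK_INDEX=1 python mnist_distributed.py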

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.examples.tutorials.mnist import input_data
from threading import Thread

import json
import logging
import os
import sys
import tensorflow as tf


# Input/output directories
tf.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                       'Directory for storing input data')
tf.flags.DEFINE_string('working_dir', '/tmp/tensorflow/mnist/working_dir',
                       'Directory under which events and output will be stored (in separate subdirectories).')

# Training parameters
tf.flags.DEFINE_integer("steps", 1500, "The number of training steps to execute.")
tf.flags.DEFINE_integer("batch_size", 64, "The batch size per step.")

FLAGS = tf.flags.FLAGS
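
# Flags are parsed from the command line when tf.app.run() invokes main(), so
# the training parameters can be overridden at launch time, e.g. (script name
# is a placeholder):
#   python mnist_distributed.py --steps=3000 --batch_size=128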


def deepnn(x):
  """deepnn builds the graph for a deep net for classifying digits.

  Args:
    x: an input tensor with the dimensions (N_examples, 784), where 784 is the
      number of pixels in a standard MNIST image.

  Returns:
    A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
    equal to the logits of classifying the digit into one of 10 classes (the
    digits 0-9). keep_prob is a scalar placeholder for the probability of
    dropout.
  """
  # Reshape to use within a convolutional neural net.
  # Last dimension is for "features" - there is only one here, since images are
  # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
  with tf.name_scope('reshape'):
    x_image = tf.reshape(x, [-1, 28, 28, 1])

  # First convolutional layer - maps one grayscale image to 32 feature maps.
  with tf.name_scope('conv1'):
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

  # Pooling layer - downsamples by 2X.
  with tf.name_scope('pool1'):
    h_pool1 = max_pool_2x2(h_conv1)

  # Second convolutional layer -- maps 32 feature maps to 64.
  with tf.name_scope('conv2'):
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

  # Second pooling layer.
  with tf.name_scope('pool2'):
    h_pool2 = max_pool_2x2(h_conv2)

  # Fully connected layer 1 -- after 2 rounds of downsampling, our 28x28 image
  # is down to 7x7x64 feature maps -- maps this to 1024 features.
  with tf.name_scope('fc1'):
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

  # Dropout - controls the complexity of the model, prevents co-adaptation of
  # features.
  with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

  # Map the 1024 features to 10 classes, one for each digit
  with tf.name_scope('fc2'):
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
  return y_conv, keep_prob


def conv2d(x, W):
  """conv2d returns a 2d convolution layer with full stride."""
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
  """max_pool_2x2 downsamples a feature map by 2X."""
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')


def weight_variable(shape):
  """weight_variable generates a weight variable of a given shape."""
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)


def bias_variable(shape):
  """bias_variable generates a bias variable of a given shape."""
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)


def create_model():
  """Creates our model and returns the target nodes to be run or populated"""
  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])

  # Define loss and optimizer
  y_ = tf.placeholder(tf.int64, [None])

  # Build the graph for the deep net
  y_conv, keep_prob = deepnn(x)

  with tf.name_scope('loss'):
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y_conv)
    cross_entropy = tf.reduce_mean(cross_entropy)

  global_step = tf.train.get_or_create_global_step()
  with tf.name_scope('adam_optimizer'):
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy, global_step=global_step)

  with tf.name_scope('accuracy'):
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
    correct_prediction = tf.cast(correct_prediction, tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

  tf.summary.scalar('cross_entropy_loss', cross_entropy)
  tf.summary.scalar('accuracy', accuracy)

  merged = tf.summary.merge_all()

  return x, y_, keep_prob, global_step, train_step, accuracy, merged


def main(_):
  logging.getLogger().setLevel(logging.INFO)

  cluster_spec_str = os.environ["CLUSTER_SPEC"]
  cluster_spec = json.loads(cluster_spec_str)
  ps_hosts = cluster_spec['ps']
  worker_hosts = cluster_spec['worker']

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  job_name = os.environ["JOB_NAME"]
  task_index = int(os.environ["TASK_INDEX"])
  server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

  if job_name == "ps":
    server.join()
  elif job_name == "worker":
    # Create our model graph. Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % task_index,
        cluster=cluster)):
      features, labels, keep_prob, global_step, train_step, accuracy, merged = create_model()
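      # Note: replica_device_setter places the tf.Variable objects created in
      # create_model() on the "ps" job (round-robin across ps tasks by default)
      # while the remaining ops stay pinned to this worker.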

    if task_index == 0:  # chief worker
      tf.gfile.MakeDirs(FLAGS.working_dir)

    # The StopAtStepHook handles stopping after running the given number of steps.
    hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.steps)]

    # Filter out all connections except those between ps and this worker, to avoid hangs
    # when another worker finishes. We are using asynchronous training, so the workers do
    # not need to communicate with each other.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
    config_proto = tf.ConfigProto(gpu_options=gpu_options,
                                  device_filters=['/job:ps', '/job:worker/task:%d' % task_index])

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0),
                                           checkpoint_dir=FLAGS.working_dir,
                                           hooks=hooks,
                                           config=config_proto) as sess:
      # Import data
      logging.info('Extracting and loading input data...')
      mnist = input_data.read_data_sets(FLAGS.data_dir)

      # Train
      logging.info('Starting training')
      i = 0
      while not sess.should_stop():
        batch = mnist.train.next_batch(FLAGS.batch_size)
        if i % 100 == 0:
          step, _, train_accuracy = sess.run([global_step, train_step, accuracy],
                                             feed_dict={features: batch[0], labels: batch[1], keep_prob: 1.0})
          logging.info('Step %d, training accuracy: %g' % (step, train_accuracy))
        else:
          sess.run([global_step, train_step],
                   feed_dict={features: batch[0], labels: batch[1], keep_prob: 0.5})
        i += 1

    logging.info('Done training!')
    sys.exit()


if __name__ == '__main__':
  tf.app.run()