Advertisement
Guest User

Untitled

a guest
Feb 22nd, 2017
189
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.19 KB | None | 0 0
import argparse
import sys

import tensorflow as tf

# Short aliases for the contrib modules used throughout this script
# (TF 1.x-era API; tf.contrib was removed in TF 2.x).
layers = tf.contrib.layers
learn = tf.contrib.learn

# Emit INFO-level logs so the per-step LoggingTensorHook output is visible.
tf.logging.set_verbosity(tf.logging.INFO)


# Populated from command-line arguments in the __main__ block before
# tf.app.run() invokes main().
FLAGS = None
  14. def conv_model(feature, target, mode):
  15.   """2-layer convolution model."""
  16.   # Convert the target to a one-hot tensor of shape (batch_size, 10) and
  17.   # with a on-value of 1 for each one-hot vector of length 10.
  18.   target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)
  19.  
  20.   # Reshape feature to 4d tensor with 2nd and 3rd dimensions being
  21.   # image width and height final dimension being the number of color channels.
  22.   feature = tf.reshape(feature, [-1, 28, 28, 1])
  23.  
  24.   # First conv layer will compute 32 features for each 5x5 patch
  25.   with tf.variable_scope('conv_layer1'):
  26.     h_conv1 = layers.conv2d(
  27.         feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
  28.     h_pool1 = tf.nn.max_pool(
  29.         h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
  30.  
  31.   # Second conv layer will compute 64 features for each 5x5 patch.
  32.   with tf.variable_scope('conv_layer2'):
  33.     h_conv2 = layers.conv2d(
  34.         h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
  35.     h_pool2 = tf.nn.max_pool(
  36.         h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
  37.     # reshape tensor into a batch of vectors
  38.     h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
  39.  
  40.   # Densely connected layer with 1024 neurons.
  41.   h_fc1 = layers.dropout(
  42.       layers.fully_connected(
  43.           h_pool2_flat, 1024, activation_fn=tf.nn.relu),
  44.       keep_prob=0.5,
  45.       is_training=mode == tf.contrib.learn.ModeKeys.TRAIN)
  46.  
  47.   # Compute logits (1 per class) and compute loss.
  48.   logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
  49.   loss = tf.losses.softmax_cross_entropy(target, logits)
  50.  
  51.   return tf.argmax(logits, 1), loss
  52.  
  53.  
def main(_):
  """Run one task of a distributed synchronous MNIST training job.

  Parameter-server tasks block in server.join(); worker tasks build the
  replicated training graph and step the MonitoredTrainingSession loop.
  """
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    # Parameter servers only serve variables; block here forever.
    server.join()
  elif FLAGS.job_name == "worker":
    ### Download and load MNIST dataset.
    mnist = learn.datasets.load_dataset('mnist')

    # Assigns ops to the local worker by default; variables are placed on the
    # ps job round-robin by replica_device_setter.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Build model...
      with tf.name_scope('input'):
        # Flattened 28x28 MNIST images (784 floats per example).
        image = tf.placeholder(tf.float32, [None, 784], name='image')
        # Labels arrive as float32 here; conv_model casts them to int32
        # before one-hot encoding.
        label = tf.placeholder(tf.float32, [None], name='label')

      # Mode is hard-wired to TRAIN, so dropout in conv_model is always active.
      predict, loss = conv_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)

      opt = tf.train.RMSPropOptimizer(0.01)
      # Wrap the optimizer so every step aggregates gradients from all
      # workers before applying a single update (fully synchronous SGD).
      opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=len(worker_hosts),
                                           total_num_replicas=len(worker_hosts))

      global_step = tf.contrib.framework.get_or_create_global_step()
      train_op = opt.minimize(loss, global_step=global_step)

      # The StopAtStepHook handles stopping after running given steps.
      # make_session_run_hook starts the sync-replica queue machinery; the
      # chief (task 0) additionally runs its initialization ops.
      hooks=[tf.train.StopAtStepHook(last_step=1000000),
             tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss}, every_n_iter=100),
             opt.make_session_run_hook(is_chief=(FLAGS.task_index == 0))]

      # The MonitoredTrainingSession takes care of session initialization,
      # restoring from a checkpoint, saving to a checkpoint, and closing when done
      # or an error occurs.
      # NOTE(review): /tmp/train_logs is a local path; for a multi-machine
      # cluster this presumably needs to be a shared filesystem — confirm.
      with tf.train.MonitoredTrainingSession(master=server.target,
                                             is_chief=(FLAGS.task_index == 0),
                                             checkpoint_dir="/tmp/train_logs",
                                             hooks=hooks,
                                             config=tf.ConfigProto(log_device_placement=True)) as mon_sess:
        while not mon_sess.should_stop():
          # Run a training step synchronously.
          image_, label_ = mnist.train.next_batch(100)
          mon_sess.run(train_op, feed_dict={image: image_, label: label_})
  108.  
  109.  
  110. if __name__ == "__main__":
  111.   parser = argparse.ArgumentParser()
  112.   parser.register("type", "bool", lambda v: v.lower() == "true")
  113.   # Flags for defining the tf.train.ClusterSpec
  114.   parser.add_argument(
  115.       "--ps_hosts",
  116.       type=str,
  117.       default="",
  118.       help="Comma-separated list of hostname:port pairs"
  119.   )
  120.   parser.add_argument(
  121.       "--worker_hosts",
  122.       type=str,
  123.       default="",
  124.       help="Comma-separated list of hostname:port pairs"
  125.   )
  126.   parser.add_argument(
  127.       "--job_name",
  128.       type=str,
  129.       default="",
  130.       help="One of 'ps', 'worker'"
  131.   )
  132.   # Flags for defining the tf.train.Server
  133.   parser.add_argument(
  134.       "--task_index",
  135.       type=int,
  136.       default=0,
  137.       help="Index of task within the job"
  138.   )
  139.   FLAGS, unparsed = parser.parse_known_args()
  140.   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement