Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import argparse
import sys

import tensorflow as tf

# Short aliases for the tf.contrib modules used throughout this script
# (TF 1.x-era API; tf.contrib was removed in TF 2.x).
layers = tf.contrib.layers
learn = tf.contrib.learn

tf.logging.set_verbosity(tf.logging.INFO)

# Populated from command-line arguments in the __main__ block below.
FLAGS = None
def conv_model(feature, target, mode):
    """2-layer convolution model for MNIST.

    Args:
      feature: flattened image batch tensor, shape (batch_size, 784).
      target: integer class labels, shape (batch_size,).
      mode: a tf.contrib.learn.ModeKeys value; dropout is applied only
        when mode == TRAIN.

    Returns:
      A (predictions, loss) pair: per-example argmax class ids and the
      softmax cross-entropy loss tensor.
    """
    # Convert the target to a one-hot tensor of shape (batch_size, 10) and
    # with an on-value of 1 for each one-hot vector of length 10.
    target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)

    # Reshape feature to a 4-D tensor with 2nd and 3rd dimensions being
    # image width and height, final dimension being the number of color
    # channels (1 for grayscale MNIST).
    feature = tf.reshape(feature, [-1, 28, 28, 1])

    # First conv layer will compute 32 features for each 5x5 patch.
    with tf.variable_scope('conv_layer1'):
        h_conv1 = layers.conv2d(
            feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
        h_pool1 = tf.nn.max_pool(
            h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    # Second conv layer will compute 64 features for each 5x5 patch.
    with tf.variable_scope('conv_layer2'):
        h_conv2 = layers.conv2d(
            h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
        h_pool2 = tf.nn.max_pool(
            h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        # Two rounds of 2x2 pooling reduce 28x28 to 7x7; flatten into a
        # batch of vectors for the dense layer.
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])

    # Densely connected layer with 1024 neurons; dropout is active only
    # during training.
    h_fc1 = layers.dropout(
        layers.fully_connected(
            h_pool2_flat, 1024, activation_fn=tf.nn.relu),
        keep_prob=0.5,
        is_training=mode == tf.contrib.learn.ModeKeys.TRAIN)

    # Compute logits (1 per class) and compute loss.
    logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
    loss = tf.losses.softmax_cross_entropy(target, logits)
    return tf.argmax(logits, 1), loss
def main(_):
    """Run one task (ps or worker) of the distributed MNIST training job.

    Reads cluster layout and role from FLAGS. Parameter-server tasks block
    forever serving variables; worker tasks build the model, run synchronous
    replicated SGD, and checkpoint to /tmp/train_logs.
    """
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # Parameter servers only host variables; block here forever.
        server.join()
    elif FLAGS.job_name == "worker":
        # Download and load MNIST dataset.
        mnist = learn.datasets.load_dataset('mnist')

        # Assigns ops to the local worker by default; variables are placed
        # on the parameter servers by the replica_device_setter.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            # Build model...
            with tf.name_scope('input'):
                image = tf.placeholder(tf.float32, [None, 784], name='image')
                label = tf.placeholder(tf.float32, [None], name='label')
            predict, loss = conv_model(image, label,
                                       tf.contrib.learn.ModeKeys.TRAIN)

            opt = tf.train.RMSPropOptimizer(0.01)
            # Wrap the optimizer so gradients from all workers are
            # aggregated before each update (synchronous training).
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=len(worker_hosts),
                total_num_replicas=len(worker_hosts))
            global_step = tf.contrib.framework.get_or_create_global_step()
            train_op = opt.minimize(loss, global_step=global_step)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [
            tf.train.StopAtStepHook(last_step=1000000),
            tf.train.LoggingTensorHook(
                tensors={'step': global_step, 'loss': loss},
                every_n_iter=100),
            # Required by SyncReplicasOptimizer; the chief initializes the
            # aggregation queues.
            opt.make_session_run_hook(is_chief=(FLAGS.task_index == 0)),
        ]

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing
        # when done or an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir="/tmp/train_logs",
                hooks=hooks,
                config=tf.ConfigProto(log_device_placement=True)) as mon_sess:
            while not mon_sess.should_stop():
                # Run a training step synchronously.
                image_, label_ = mnist.train.next_batch(100)
                mon_sess.run(train_op,
                             feed_dict={image: image_, label: label_})
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Allow "--flag true/false" style boolean parsing for any bool flags.
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec.
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server.
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    # tf.app.run forwards any unparsed flags to TensorFlow's own parser
    # and then invokes main().
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement