Guest User

Untitled

a guest
Jun 30th, 2017
347
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.16 KB | None | 0 0
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. from tensorflow.python.client import device_lib
  4. from tensorflow.python.ops import variable_scope
  5.  
  6.  
  7.  
  8. def get_available_gpus():
  9.     local_device_protos = device_lib.list_local_devices()
  10.     return [x.name for x in local_device_protos if x.device_type == 'GPU']
  11.  
  12.  
  13.  
  14. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  15.  
  16.  
  17.  
  18. gpus = get_available_gpus()
  19. trainers = []
  20. accs = []
  21. xs = []
  22. ys = []
  23.  
  24.  
  25.  
  26. for i_, gpu_id in enumerate(gpus):
  27.  
  28.     with tf.device(gpu_id):
  29.         # [Build graph in here.]
  30.  
  31.         with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=i_>0):
  32.  
  33.             x = tf.placeholder(tf.float32, [None, 784])
  34.             W = tf.Variable(tf.zeros([784, 10]))
  35.             b = tf.Variable(tf.zeros([10]))
  36.             y = tf.nn.softmax(tf.matmul(x, W) + b)
  37.             y_ = tf.placeholder(tf.float32, [None, 10])
  38.  
  39.             cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
  40.  
  41.             train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy)
  42.  
  43.             correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
  44.             accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  45.  
  46.             trainers += [train_step]
  47.             accs += [accuracy]
  48.             xs += [x]
  49.             ys += [y_]
  50.  
  51.            
  52. # Start an interactive tensorflow session
  53. sess = tf.Session()
  54.  
  55. # Initialize all variables associated with this session
  56. sess.run(tf.initialize_all_variables())
  57.  
  58.  
  59.  
  60. gpu_num = 15
  61.  
  62. with tf.device(gpus[gpu_num]):
  63.  
  64.     for _ in range(20000):
  65.         batch_xs, batch_ys = mnist.train.next_batch(100)
  66.         sess.run(trainers[gpu_num], feed_dict={xs[gpu_num]: batch_xs, ys[gpu_num]: batch_ys})
  67.         break
  68.        
  69.        
  70.        
  71. gpu_num = 15
  72.  
  73. print "Accuracy on gpu 15"
  74. print(sess.run(accs[gpu_num], feed_dict={xs[gpu_num]: mnist.test.images, ys[gpu_num]: mnist.test.labels}))
  75. print
  76.  
  77.  
  78.  
  79. gpu_num = 14
  80.  
  81. print "Accuracy on gpu 14"
  82. print(sess.run(accs[gpu_num], feed_dict={xs[gpu_num]: mnist.test.images, ys[gpu_num]: mnist.test.labels}))
Advertisement
Add Comment
Please, Sign In to add comment