Advertisement
mugs

Deep Neural Network in TensorFlow - type along

Jan 18th, 2018
141
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.68 KB | None | 0 0
  1. from tensorflow.examples.tutorials.mnist import input_data
  2. mnist = input_data.read_data_sets(".", one_hot=True, reshape=False)
  3.  
  4. import tensorflow as tf
  5.  
  6. learning_rate = 0.001
  7. training_epochs = 20
  8. batch_size = 128
  9. display_step = 1
  10.  
  11. n_input = 784
  12. n_classes = 10
  13.  
  14. n_hidden_layer = 256
  15.  
  16. weights = {
  17.     'hidden_layer':tf.Variable(tf.random_normal([n_input, n_hidden_layer])),
  18.     'out':tf.Variable(tf.random_normal([n_hidden_layer, n_classes]))
  19. }
  20.  
  21. biases = {
  22.     'hidden_layer': tf.Variable(tf.random_normal([n_input, n_hidden_layer])),
  23.     'out':tf.Variable(tf.random_normal([n_classes]))
  24. }
  25.  
  26. x = tf.placeholder("float", [None, 28, 28, 1])
  27. y = tf.placeholder("float", [None, n_classes])
  28.  
  29. x_flat = tf.reshape(x,[-1, n_input])
  30.  
  31. layer_1 = tf.add(tf.matmul(x_flat, weights['hidden_layer']), biases['hidden_layer'])
  32. layer_1 = tf.nn.relu(layer_1)
  33.  
  34. # failed with -
  35. # InvalidArgumentError: Dimensions must be equal, but are 784 and 256 for 'Add_5' (op: 'Add') with input shapes: [784,10], [256,10].
  36. logits = tf.add(tf.matmul(layer_1, weights['out']), biases['out'])
  37.  
  38. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels = y))
  39. optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
  40.  
  41. init = tf.global_variables_initializer()
  42. with tf.Session() as sess:
  43.     sess.run(init)
  44.     for epoch in range(training_epochs):
  45.         total_batch = int(mnist.train.num_examples/batch_size)
  46.         for i in range(total_batch):
  47.             batch_x, batch_y = mnist.train.next_batch(batch_size)
  48.             #InvalidArgumentError: Incompatible shapes: [128,256] vs. [784,256]
  49.             sess.run(optimizer, feed_dict = {x:batch_x, y:batch_y})
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement