LogReg

'''
A logistic regression learning algorithm example using the TensorFlow library.
This example uses the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/)

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf

# Import MNIST data (downloaded to /tmp/data/ on first run)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
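
# one_hot=True encodes each label as a 10-dimensional indicator vector, e.g.
# digit 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], matching the 10-way softmax
# output constructed below.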

# Parameters
learning_rate = 0.01
training_epochs = 20
batch_size = 100
display_step = 1

# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data image of shape 28*28=784
y = tf.placeholder(tf.float32, [None, 10])   # 0-9 digits recognition => 10 classes
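
# The leading None dimension leaves the batch size unspecified, so the same
# placeholders serve both the 100-example training mini-batches and the full
# 10,000-example MNIST test set at evaluation time.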

# Set model weights
with tf.variable_scope("foo"):
    W = tf.get_variable("W", initializer=tf.zeros([784, 10]))
    b = tf.get_variable("b", initializer=tf.zeros([10]))
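
# Hedged aside: because W and b live in variable scope "foo", the same
# variables can be fetched again by name elsewhere (a sketch, not something
# this script needs):
#   with tf.variable_scope("foo", reuse=True):
#       W_again = tf.get_variable("W")  # same underlying variable as W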

# Construct model
pred = tf.nn.softmax(tf.matmul(x, W) + b)  # Softmax

# Minimize error using cross entropy; pred is clipped away from zero so
# tf.log never receives an exact 0
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(tf.clip_by_value(pred, 1e-10, 1.0)), axis=1))
# Gradient Descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
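
# Note: an equivalent, more numerically stable formulation keeps the raw
# logits and lets TensorFlow fuse softmax and cross-entropy (a sketch using
# the same graph definitions, not part of the original script):
#   logits = tf.matmul(x, W) + b
#   cost = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))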

# Initializing the variables
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

# Launch the graph
with tf.Session() as sess:
    sess.run(init_op)

    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

    print("Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
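    # tf.argmax(pred, 1) picks the most probable class per example; comparing
    # it with the label's argmax gives booleans, and casting them to float
    # makes the mean equal the fraction of correctly classified images.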

    # NOTE: this checkpoint path is machine-specific; adjust it for your setup
    save_path = saver.save(sess, '/Users/mac/PycharmProjects/untitled1/MyModel2', write_meta_graph=True)
    print("Model saved in file: %s" % save_path)
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
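
To reload the checkpoint later, rebuild the same graph and restore the weights into a fresh session; a minimal sketch, assuming the machine-specific path used above and that x, y, accuracy, and saver are still defined:

with tf.Session() as sess:
    saver.restore(sess, '/Users/mac/PycharmProjects/untitled1/MyModel2')
    print("Accuracy (restored):",
          accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))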