Advertisement
Guest User

LogRegSaver

a guest
Jan 26th, 2017
241
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.68 KB | None | 0 0
  1. '''
  2. A logistic regression learning algorithm example using TensorFlow library.
  3. This example is using the MNIST database of handwritten digits
  4. (http://yann.lecun.com/exdb/mnist/)
  5.  
  6. Author: Aymeric Damien
  7. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  8. '''
  9.  
  10. from __future__ import print_function
  11.  
  12. import tensorflow as tf
  13.  
  14. # Import MNIST data
  15. from tensorflow.examples.tutorials.mnist import input_data
  16. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
  17.  
  18. # Parameters
  19. learning_rate = 0.01
  20. training_epochs = 20
  21. batch_size = 100
  22. display_step = 1
  23.  
  24. # tf Graph Input
  25. x = tf.placeholder(tf.float32, [None, 784])  # mnist data image of shape 28*28=784
  26. y = tf.placeholder(tf.float32, [None, 10])  # 0-9 digits recognition => 10 classes
  27.  
  28. # Set model weights
  29. with tf.variable_scope("my"):
  30.     W = tf.Variable(tf.zeros([784, 10]), name = "W")
  31.     b = tf.Variable(tf.zeros([10]), name = "b")
  32.  
  33.  
  34. # Construct model
  35. pred = tf.nn.softmax(tf.matmul(x, W) + b)   # Softmax
  36.  
  37. # Minimize error using cross entropy
  38. cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
  39. # Gradient Descent
  40. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
  41.  
  42. # Initializing the variables
  43. init_op = tf.global_variables_initializer()
  44. saver = tf.train.Saver()
  45.  
  46. # Launch the graph
  47. with tf.Session() as sess:
  48.     sess.run(init_op)
  49.  
  50.     # Training cycle
  51.     for epoch in range(training_epochs):
  52.         avg_cost = 0.
  53.         total_batch = int(mnist.train.num_examples/batch_size)
  54.         # Loop over all batches
  55.         for i in range(total_batch):
  56.             batch_xs, batch_ys = mnist.train.next_batch(batch_size)
  57.             # Run optimization op (backprop) and cost op (to get loss value)
  58.             _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
  59.                                                           y: batch_ys})
  60.             # Compute average loss
  61.             avg_cost += c / total_batch
  62.         # Display logs per epoch step
  63.         if (epoch+1) % display_step == 0:
  64.             print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))
  65.  
  66.     print("Optimization Finished!")
  67.  
  68.     # Test model
  69.     correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  70.     # Calculate accuracy
  71.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  72.  
  73.     save_path = saver.save(sess, "/Users/mac/PycharmProjects/untitled1/MyModel", write_meta_graph=True)
  74.     print("Model saved in file: %s" % save_path)
  75.     print("Accuracy_old:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  76.  
  77.     # Save the variables to disk.
  78.  
  79.     def restore():
  80.  
  81.         new_saver = tf.train.import_meta_graph('MyModel.meta')
  82.         new_saver.restore(sess, tf.train.latest_checkpoint('./'))
  83.         all_vars = tf.get_collection('vars')
  84.         for v in all_vars:
  85.             v_ = sess.run(v)
  86.             print(v_)
  87.  
  88.         with tf.variable_scope("my"):
  89.             tf.get_variable_scope().reuse_variables()
  90.             idx = tf.constant([0])
  91.             temp_var = tf.get_variable("W")  #[v[0] for v in all_vars if v.name == "W"]
  92.             size_1 = tf.gather(temp_var, idx)
  93.             size_2 = tf.get_variable("b")
  94.  
  95.         ones_mask = tf.Variable(tf.ones([size_1, size_2]))
  96.         index_num = tf.Variable(tf.random_uniform([size_1, ]))
  97.         indexNum = tf.cast(index_num, tf.int64)
  98.         update = tf.scatter_update(ones_mask, indexNum, tf.zeros([size_1, size_2]))
  99.  
  100.         assign_op = W.assign(tf.mul(W, update))
  101.         sess.run(tf.global_variables_initializer())
  102.         sess.run(assign_op)
  103.         print("Accuracy_new:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  104.  
  105.     restore()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement