Guest User

Untitled

a guest
Nov 20th, 2017
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.22 KB | None | 0 0
  1. # coding=utf-8
  2. from color_1 import read_and_decode, get_batch, get_test_batch
  3. import AlexNet
  4. import cv2
  5. import os
  6. import time
  7. import numpy as np
  8. import tensorflow as tf
  9. import AlexNet_train
  10. import math
  11.  
  12. batch_size=128
  13. num_examples = 1000
  14. crop_size=56
  15.  
  16. def evaluate(test_x, test_y):
  17. image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
  18. label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')
  19.  
  20. y = AlexNet.inference(image_holder,evaluate,None)
  21.  
  22. correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
  23. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  24. saver = tf.train.Saver()
  25. with tf.Session() as sess:
  26. init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  27. coord = tf.train.Coordinator()
  28. sess.run(init_op)
  29. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  30. ckpt=tf.train.get_checkpoint_state(AlexNet_train.MODEL_SAVE_PATH)
  31. if ckpt and ckpt.model_checkpoint_path:
  32. ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
  33. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  34. saver.restore(sess, os.path.join(AlexNet_train.MODEL_SAVE_PATH, ckpt_name))
  35. print('Loading success, global_step is %s' % global_step)
  36. step=0
  37.  
  38. image_batch, label_batch = sess.run([test_x, test_y])
  39. accuracy_score=sess.run(accuracy,feed_dict={image_holder: image_batch,
  40. label_holder: label_batch})
  41. print("After %s training step(s),validation "
  42. "precision=%g" % (global_step, accuracy_score))
  43. coord.request_stop()
  44. coord.join(threads)
  45.  
  46. def main(argv=None):
  47. test_image, test_label = read_and_decode('val.tfrecords')
  48.  
  49. test_images, test_labels = get_test_batch(test_image, test_label, batch_size, crop_size)
  50.  
  51. evaluate(test_images, test_labels)
  52.  
  53.  
  54. if __name__=='__main__':
  55. tf.app.run()
  56.  
  57. Traceback (most recent call last):
  58. File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
  59. tf.app.run()
  60. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
  61. _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  62. File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
  63. evaluate(test_images, test_labels)
  64. File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 45, in evaluate
  65. label_holder: label_batch})
  66. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
  67. run_metadata_ptr)
  68. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
  69. feed_dict_string, options, run_metadata)
  70. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
  71. target_list, options, run_metadata)
  72. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
  73. raise type(e)(node_def, op, message)
  74. tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected dimension in the range [-1, 1), but got 1
  75. [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]
  76.  
  77. Caused by op u'ArgMax_1', defined at:
  78. File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
  79. tf.app.run()
  80. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
  81. _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  82. File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
  83. evaluate(test_images, test_labels)
  84. File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 22, in evaluate
  85. correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
  86. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 263, in argmax
  87. return gen_math_ops.arg_max(input, axis, name)
  88. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 168, in arg_max
  89. name=name)
  90. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
  91. op_def=op_def)
  92. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
  93. original_op=self._default_original_op, op_def=op_def)
  94. File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
  95. self._traceback = _extract_stack()
  96.  
  97. InvalidArgumentError (see above for traceback): Expected dimension in the range [-1, 1), but got 1
  98. [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]
Add Comment
Please, Sign In to add comment