Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # coding=utf-8
- from color_1 import read_and_decode, get_batch, get_test_batch
- import AlexNet
- import cv2
- import os
- import time
- import numpy as np
- import tensorflow as tf
- import AlexNet_train
- import math
# Evaluation settings shared by evaluate() and main().
batch_size = 128     # examples per test batch; also the fixed leading dim of evaluate()'s placeholders
num_examples = 1000  # NOTE(review): not referenced by the code shown here — confirm it is still needed
crop_size = 56       # crop size handed to get_test_batch (presumably square side length in pixels)
def evaluate(test_x, test_y):
    """Restore the latest AlexNet checkpoint and print accuracy on one test batch.

    Args:
        test_x: image batch tensor (e.g. from get_test_batch); its evaluated
            value is fed into a [batch_size, crop_size, crop_size, 3] float32
            placeholder.
        test_y: label batch tensor; its evaluated value is fed into a
            [batch_size] int32 placeholder, i.e. one integer class id per
            example (sparse labels, not one-hot).

    If no checkpoint is found under AlexNet_train.MODEL_SAVE_PATH, a message
    is printed and the function returns without evaluating.
    """
    image_holder = tf.placeholder(
        tf.float32, [batch_size, crop_size, crop_size, 3], name='x-input')
    label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')
    # NOTE(review): the second argument is the `evaluate` function object
    # itself, which is truthy. If AlexNet.inference treats this argument as a
    # "train" flag (e.g. to enable dropout), False is probably intended — confirm.
    y = AlexNet.inference(image_holder, evaluate, None)
    # label_holder is rank 1 (class ids), so it has no axis 1 to argmax over;
    # compare the predicted class ids with the labels directly. argmax yields
    # int64, so cast to match the int32 labels before tf.equal.
    predictions = tf.cast(tf.argmax(y, 1), tf.int32)
    correct_prediction = tf.equal(predictions, label_holder)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        sess.run(init_op)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Always stop the input queue threads, even on early return or error.
        try:
            ckpt = tf.train.get_checkpoint_state(AlexNet_train.MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                # Without a checkpoint the model would run on random weights.
                print('No checkpoint file found in %s'
                      % AlexNet_train.MODEL_SAVE_PATH)
                return
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # Checkpoint files are named "<prefix>-<global_step>".
            global_step = ckpt_name.split('-')[-1]
            saver.restore(sess, os.path.join(AlexNet_train.MODEL_SAVE_PATH,
                                             ckpt_name))
            print('Loading success, global_step is %s' % global_step)
            image_batch, label_batch = sess.run([test_x, test_y])
            accuracy_score = sess.run(accuracy,
                                      feed_dict={image_holder: image_batch,
                                                 label_holder: label_batch})
            print("After %s training step(s),validation "
                  "precision=%g" % (global_step, accuracy_score))
        finally:
            coord.request_stop()
            coord.join(threads)
def main(argv=None):
    """Build the validation input pipeline and run evaluate() on it.

    argv is accepted for tf.app.run compatibility and is not used.
    """
    single_image, single_label = read_and_decode('val.tfrecords')
    batched = get_test_batch(single_image, single_label, batch_size, crop_size)
    evaluate(batched[0], batched[1])
if __name__ == '__main__':
    # Explicitly hand tf.app.run the entry point it would otherwise look up
    # as __main__.main.
    tf.app.run(main=main)
- Traceback (most recent call last):
- File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
- tf.app.run()
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
- _sys.exit(main(_sys.argv[:1] + flags_passthrough))
- File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
- evaluate(test_images, test_labels)
- File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 45, in evaluate
- label_holder: label_batch})
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
- run_metadata_ptr)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
- feed_dict_string, options, run_metadata)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
- target_list, options, run_metadata)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
- raise type(e)(node_def, op, message)
- tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected dimension in the range [-1, 1), but got 1
- [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]
- Caused by op u'ArgMax_1', defined at:
- File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
- tf.app.run()
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
- _sys.exit(main(_sys.argv[:1] + flags_passthrough))
- File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
- evaluate(test_images, test_labels)
- File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 22, in evaluate
- correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 263, in argmax
- return gen_math_ops.arg_max(input, axis, name)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 168, in arg_max
- name=name)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
- op_def=op_def)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
- original_op=self._default_original_op, op_def=op_def)
- File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
- self._traceback = _extract_stack()
- InvalidArgumentError (see above for traceback): Expected dimension in the range [-1, 1), but got 1
- [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]
Add Comment
Please, Sign In to add comment