Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Evaluation script for a previously trained E-CNN: builds the evaluation
# input pipeline, restores the saved meta graph and weights, then runs the
# stored loss op NUM_STEPS times and prints each result.

# Build the data input
print("Bulding data input...")
with tf.device('/cpu:0'):
    aug_img, et_img, is_img = ecnn_read_evaluation_images(
        trainset_size=NUM_STEPS,
        batch_size=BATCH_SIZE,
        cropped=CROPPED,
        base_folder=BASE_FOLDER,
    )

# NOTE(review): the tensors created above live next to the imported graph;
# the restored loss op was presumably built against the training graph's own
# input tensors, so it may not consume aug_img / et_img unless they are
# remapped on import -- confirm.

# Load the ECNN graph, clear its device settings (we want to run the loss
# function on a CPU)
print("Loading meta graph...")
saver = tf.train.import_meta_graph(os.path.join(save_path, "ecnn.meta"),
                                   clear_devices=True)

# Start time measurement
start_time = time.time()

# The context manager closes the session on exit, so no explicit
# sess.close() is needed.
with tf.Session() as sess:
    # Reference: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py
    # https://stackoverflow.com/questions/44510024/how-to-restore-my-loss-from-a-saved-meta-graph

    # Restore variable weights
    print("Restoring variable weights...")
    saver.restore(sess, os.path.join(save_path, "ecnn"))

    # Restore saved functions out of the meta graph. Only the loss op is
    # needed for evaluation; the training op is deliberately not looked up,
    # since indexing an empty collection would raise IndexError.
    print("Restoring functions...")
    loss_op = tf.get_collection('loss_op')[0]

    # Merge all summaries (the merged op is created but not run below).
    merged = tf.summary.merge_all()

    # Start input enqueue threads.
    print("Starting input queues...")
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # Evaluation cycle
        print("Starting evaluation...")
        for step in range(1, NUM_STEPS + 1):
            print(step)
            # Fetching a one-element list keeps the printed format ([value]).
            loss = sess.run([loss_op])
            print(loss)
    finally:
        # Always stop and join the queue threads, even if evaluation failed,
        # so that no background threads are left running.
        coord.request_stop()
        coord.join(threads)

print("--- %s seconds ---" % (time.time() - start_time))
# Graph construction: run the E-CNN on the augmented input image and attach
# a mean-squared-error loss against the target edge image.
# NOTE(review): the paste lost its indentation; the summary calls are assumed
# to sit inside their respective scopes -- confirm against the original file.

# Forward pass of the E-CNN, with an image summary of the predicted edge map
with tf.variable_scope('E-CNN'):
    trained_edge = ecnn(aug_img)
    tf.summary.image("trained", trained_edge)

# Mean squared error between the prediction and et_img, plus a scalar summary
with tf.variable_scope('loss'):
    loss_op = tf.losses.mean_squared_error(
        labels=et_img,
        predictions=trained_edge,
    )
    tf.summary.scalar("loss", loss_op)

# Register the loss op in a named collection so it can be fetched later with
# tf.get_collection('loss_op')
tf.add_to_collection('loss_op', loss_op)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement