Advertisement
Guest User

Untitled

a guest
Dec 15th, 2017
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.19 KB | None | 0 0
  1. # Build the data input
  2. print("Bulding data input...")
  3. with tf.device('/cpu:0'):
  4. aug_img, et_img, is_img = ecnn_read_evaluation_images(trainset_size=NUM_STEPS,
  5. batch_size=BATCH_SIZE,
  6. cropped=CROPPED,
  7. base_folder=BASE_FOLDER)
  8.  
  9.  
  10. # Load the ECNN graph, clear its device settings (we want to run the loss function on a CPU)
  11. print("Loading meta graph...")
  12. saver = tf.train.import_meta_graph(os.path.join(save_path, "ecnn.meta"), clear_devices=True)
  13.  
  14. # Start time measurement
  15. start_time = time.time()
  16.  
  17. with tf.Session() as sess:
  18. # Reference: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py
  19. # https://stackoverflow.com/questions/44510024/how-to-restore-my-loss-from-a-saved-meta-graph
  20.  
  21. # Restore variable weights
  22. print("Restoring variable weights...")
  23. saver.restore(sess, os.path.join(save_path, "ecnn"))
  24.  
  25. # Restore saved functions out of the meta graph
  26. print("Restoring functions...")
  27. train_op = tf.get_collection('train_op')[0]
  28. loss_op = tf.get_collection('loss_op')[0]
  29.  
  30. # Merge all summaries
  31. merged = tf.summary.merge_all()
  32.  
  33. # Start input enqueue threads.
  34. print("Starting input queues...")
  35. coord = tf.train.Coordinator()
  36. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  37.  
  38. # Evaluation cycle
  39. print("Starting evaluation...")
  40. for step in range(1, NUM_STEPS + 1):
  41. print(step)
  42. loss = sess.run([loss_op])
  43. print(loss)
  44.  
  45. # When done, ask the threads to stop
  46. coord.request_stop()
  47. # Wait for threads to finish.
  48. coord.join(threads)
  49.  
  50. sess.close()
  51.  
  52. print("--- %s seconds ---" % (time.time() - start_time))
  53.  
  54. # Predict the edge map with the E-CNN
  55. with tf.variable_scope('E-CNN'):
  56. trained_edge = ecnn(aug_img)
  57. tf.summary.image("trained", trained_edge)
  58.  
  59. # Define a loss function
  60. with tf.variable_scope('loss'):
  61. loss_op = tf.losses.mean_squared_error(labels=et_img,
  62. predictions=trained_edge)
  63. tf.summary.scalar("loss", loss_op)
  64. # put op in collection
  65. tf.add_to_collection('loss_op', loss_op)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement