Advertisement
Guest User

Untitled

a guest
Aug 24th, 2016
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.98 KB | None | 0 0
  1. def get_data(image_name_list, num_epochs, scope_name, num_class = NUM_CLASS):
  2. with tf.variable_scope(scope_name) as scope:
  3. images_path = [os.path.join(DATASET_DIR, i+'.jpg') for i in image_name_list]
  4. gts_path = [os.path.join(GT_DIR, i+'.png') for i in image_name_list]
  5. seed = random.randint(0, 2147483647)
  6. image_name_queue = tf.train.string_input_producer(images_path, num_epochs=num_epochs, shuffle=False, seed = seed)
  7. gt_name_queue = tf.train.string_input_producer(gts_path, num_epochs=num_epochs, shuffle=False, seed = seed)
  8. reader = tf.WholeFileReader()
  9. image_key, image_value = reader.read(image_name_queue)
  10. my_image = tf.image.decode_jpeg(image_value)
  11. my_image = tf.cast(my_image, tf.float32)
  12. my_image = tf.expand_dims(my_image, 0)
  13. gt_key, gt_value = reader.read(gt_name_queue)
  14. # gt stands for ground truth
  15. my_gt = tf.cast(tf.image.decode_png(gt_value, channels = 1), tf.float32)
  16. my_gt = tf.one_hot(tf.cast(my_gt, tf.int32), NUM_CLASS)
  17. return my_image, my_gt
  18.  
  19. train_image, train_gt = get_data(train_files, NUM_EPOCH, 'training')
  20. val_image, val_gt = get_data(val_files, NUM_EPOCH, 'validation')
  21. with tf.variable_scope('FCN16') as scope:
  22. train_vgg16_fcn = fcn16_vgg.FCN16VGG()
  23. train_vgg16_fcn.build(train_image, train=True, num_classes=NUM_CLASS, keep_prob = KEEP_PROB)
  24. scope.reuse_variables()
  25. val_vgg16_fcn = fcn16_vgg.FCN16VGG()
  26. val_vgg16_fcn.build(val_image, train=False, num_classes=NUM_CLASS, keep_prob = 1)
  27. """
  28. Define the loss, evaluation metric, summary, saver in the computation graph. Initialize variables and start a session.
  29. """
  30. for epoch in range(starting_epoch, NUM_EPOCH):
  31. for i in range(train_num):
  32. _, loss_value, shape = sess.run([train_op, train_entropy_loss, tf.shape(train_image)])
  33. print shape
  34. for i in range(val_num):
  35. loss_value, shape = sess.run([val_entropy_loss, tf.shape(val_image)])
  36. print shape
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement