Advertisement
Guest User

Untitled

a guest
Jan 19th, 2017
62
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.16 KB | None | 0 0
  1. def create_eval_ops(model_input, pred_y, all_triples, eval_triples, n_entity,
  2. top_k, idx_1=0, idx_2=1, idx_3=2):
  3. """ Evaluation operations for any model.
  4.  
  5. For given <h,r> predict t, idx_1 = 0, idx_2 = 1, idx_3 = 2
  6. For given <t,r> predict h, idx_1 = 2, idx_2 = 1, idx_3 = 0
  7.  
  8. :param model_input: N by 3 matrix, each row is a h,r,t pair
  9. :param pred_y: N by ENTITY_VOCAB matrix
  10. :param all_triples: M by 3 matrix, contains all triples in the KG
  11. :param eval_triples: M_{eval} by 3 matrix, contains all triples that will be
  12. evaluated, this is a subset of all_triples. model_input
  13. is a subset of eval_triples where the joint index
  14. model_input[idx_1] and model_input[idx_2] is unique in
  15. model_input
  16. :param n_entity: Number of unique entities in the KG
  17. :param top_k: Parameter of Hits@top_k
  18. :param idx_1: First index of the <?,r> pair
  19. :param idx_2: Second index of the <?,r> pair
  20. :param idx_3: Target index in the h,r,t triple
  21. :return:
  22. """
  23.  
  24. def get_id_mask(hrt, triples):
  25. return tf.logical_and(tf.equal(hrt[idx_1], triples[:, idx_1]),
  26. tf.equal(hrt[idx_2], triples[:, idx_2]))
  27.  
  28. def calculate_metrics(tensors):
  29. # eval_hrt, a 3 element h,r,t triple
  30. eval_hrt = tensors
  31.  
  32. # find the entity_vocab vector row id of the given h,r pair
  33. pred_y_mask = get_id_mask(eval_hrt, model_input)
  34. pred_score = tf.reshape(tf.boolean_mask(pred_y, pred_y_mask), [-1])
  35.  
  36. # score of current tail
  37. target_score = pred_score[eval_hrt[idx_3]]
  38.  
  39. triple_mask = get_id_mask(eval_hrt, all_triples)
  40. # disabling validate_indices will disable duplication check
  41. entity_mask = tf.sparse_to_dense(tf.boolean_mask(all_triples[:, idx_3], triple_mask),
  42. output_shape=[n_entity],
  43. sparse_values=True,
  44. default_value=False,
  45. validate_indices=False)
  46. # After masking, [i,j] will equals to min_score - 1e-5 if it is a positive instance
  47. masked_pred_score = pred_score * tf.cast(tf.logical_not(entity_mask), tf.float32) - \
  48. tf.cast(entity_mask, tf.float32) * 1e30
  49.  
  50. # Count how many entities has a score larger than target
  51. def get_rank(score, entity_scores):
  52. return tf.reduce_sum(tf.cast(tf.greater(score, entity_scores), tf.int32)) + 1
  53.  
  54. unfiltered_rank = get_rank(pred_score, target_score)
  55. filtered_rank = get_rank(masked_pred_score, target_score)
  56.  
  57. unfiltered_hit = tf.where(unfiltered_rank <= top_k, 1, 0)
  58. filtered_hit = tf.where(filtered_rank <= top_k, 1, 0)
  59.  
  60. return tf.stack(
  61. [tf.cast(x, tf.float32) for x in [unfiltered_rank, filtered_rank, unfiltered_hit, filtered_hit]])
  62.  
  63. metrics = tf.reduce_mean(
  64. tf.map_fn(calculate_metrics, eval_triples,
  65. dtype=tf.float32, parallel_iterations=20,
  66. back_prop=False, swap_memory=True),
  67. axis=0)
  68.  
  69. return metrics
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement