Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def create_eval_ops(model_input, pred_y, all_triples, eval_triples, n_entity,
                    top_k, idx_1=0, idx_2=1, idx_3=2):
    """Build the rank / Hits@top_k evaluation ops for a link-prediction model.

    The index arguments select the prediction direction:

    * given <h,r>, predict t: idx_1 = 0, idx_2 = 1, idx_3 = 2
    * given <t,r>, predict h: idx_1 = 2, idx_2 = 1, idx_3 = 0

    :param model_input: N by 3 matrix, one h,r,t triple per row
    :param pred_y: N by ENTITY_VOCAB score matrix; its rows are selected with
                   a mask computed over the rows of model_input
    :param all_triples: M by 3 matrix holding every triple in the KG
    :param eval_triples: M_{eval} by 3 matrix of the triples to evaluate, a
                         subset of all_triples. model_input is a subset of
                         eval_triples in which the pair
                         (row[idx_1], row[idx_2]) is unique
    :param n_entity: number of unique entities in the KG
    :param top_k: the k of Hits@top_k
    :param idx_1: first column of the query pair
    :param idx_2: second column of the query pair
    :param idx_3: column of the entity being predicted
    :return: float32 tensor with 4 elements, each averaged over eval_triples:
             [unfiltered rank, filtered rank, unfiltered hit, filtered hit]
    """

    def match_query(triple, table):
        # Boolean mask over the rows of `table` that share the query pair
        # (columns idx_1 and idx_2) with `triple`.
        same_first = tf.equal(triple[idx_1], table[:, idx_1])
        same_second = tf.equal(triple[idx_2], table[:, idx_2])
        return tf.logical_and(same_first, same_second)

    def rank_of(candidate_scores, reference_score):
        # One plus the number of candidates scoring strictly above the
        # reference, so ties do not worsen the rank.
        beats_reference = tf.greater(candidate_scores, reference_score)
        return tf.reduce_sum(tf.cast(beats_reference, tf.int32)) + 1

    def hit_at_k(rank):
        # 1 when the rank falls inside the top_k cut-off, otherwise 0.
        return tf.where(rank <= top_k, 1, 0)

    def metrics_for_triple(triple):
        # `triple` is a single 3-element h,r,t row taken from eval_triples.
        # Select the score row(s) of pred_y belonging to this query pair and
        # flatten them; the pair is expected to be unique in model_input, so
        # this is one ENTITY_VOCAB-sized vector.
        row_selector = match_query(triple, model_input)
        scores = tf.reshape(tf.boolean_mask(pred_y, row_selector), [-1])

        # Score the model assigned to the entity that is actually correct.
        gold_score = scores[triple[idx_3]]

        # Dense boolean vector over entities marking every target entity that
        # appears in all_triples with this query pair.
        # validate_indices=False skips index validation (per the original
        # note, this disables the duplicate check).
        known_rows = match_query(triple, all_triples)
        known_targets = tf.boolean_mask(all_triples[:, idx_3], known_rows)
        is_known = tf.sparse_to_dense(known_targets,
                                      output_shape=[n_entity],
                                      sparse_values=True,
                                      default_value=False,
                                      validate_indices=False)

        # Filtered setting: known entities get score 0 - 1e30, so they can
        # never outrank the target; every other score is left unchanged.
        keep = tf.cast(tf.logical_not(is_known), tf.float32)
        drop = tf.cast(is_known, tf.float32)
        filtered_scores = scores * keep - drop * 1e30

        raw_rank = rank_of(scores, gold_score)
        filt_rank = rank_of(filtered_scores, gold_score)
        per_triple = (raw_rank, filt_rank,
                      hit_at_k(raw_rank), hit_at_k(filt_rank))
        return tf.stack([tf.cast(value, tf.float32) for value in per_triple])

    # Evaluate each triple independently, then average every metric.
    per_triple_metrics = tf.map_fn(metrics_for_triple, eval_triples,
                                   dtype=tf.float32, parallel_iterations=20,
                                   back_prop=False, swap_memory=True)
    return tf.reduce_mean(per_triple_metrics, axis=0)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement