Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def as_keras_metric(method):
    """Turn a TensorFlow 1.x streaming metric into a Keras-compatible metric.

    ``method`` must return a ``(value, update_op)`` pair, which is the
    convention followed by the ``tf.metrics.*`` functions. The returned
    wrapper ties ``update_op`` to ``value``. Every evaluation of the metric
    tensor therefore also runs the update, and Keras sees the running
    (streaming) value.

    Args:
        method: Callable returning ``(value, update_op)``.

    Returns:
        A wrapper carrying ``method``'s name and docstring (via
        ``functools.wraps``). It forwards all arguments to ``method`` and
        returns a single tensor.
    """
    import functools

    from keras import backend as K
    import tensorflow as tf

    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        # Forward arguments unchanged. The wrapped metric receives exactly
        # what its caller passes, e.g. (y_true, y_pred).
        value, update_op = method(*args, **kwargs)
        # tf.metrics.* keep their accumulators in local variables, which must
        # be initialised before use.
        # NOTE(review): this initialises *all* local variables in the Keras
        # session, so it may also reset other streaming metrics created
        # earlier. Confirm that is acceptable.
        K.get_session().run(tf.local_variables_initializer())
        # Make evaluating the returned tensor also execute the update op.
        with tf.control_dependencies([update_op]):
            value = tf.identity(value)
        return value

    return wrapper
@as_keras_metric
def bmac_metric(Y_true, Y_pred):
    """Streaming mean per-class accuracy for a 3-class problem.

    Both tensors are reduced with ``argmax`` over axis 1. They are therefore
    assumed to be ``(batch, num_classes)`` one-hot labels and class scores
    (TODO: confirm against the model output).

    Returns:
        The ``(mean_accuracy, update_op)`` pair from
        ``tf.metrics.mean_per_class_accuracy``. The ``as_keras_metric``
        decorator turns this pair into a single Keras metric tensor.
    """
    # The only tensorflow import visible in this file is local to
    # as_keras_metric. ``tf`` must be imported here, or the call below
    # raises NameError.
    import tensorflow as tf

    num_classes = 3  # class count hard-coded by the original metric
    labels = tf.argmax(Y_true, axis=1)
    predictions = tf.argmax(Y_pred, axis=1)
    return tf.metrics.mean_per_class_accuracy(labels, predictions, num_classes)
Add Comment
Please sign in to add a comment.