Advertisement
Guest User

Untitled

a guest
Oct 21st, 2019
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.16 KB | None | 0 0
  1. """Loss functions for semi-supervised learning."""
  2.  
  3. import tensorflow as tf
  4.  
  5.  
  6. def ssl_binary_crossentropy(y_true, y_pred, missing_indicator=-1,
  7. from_logits=False):
  8. """Binary cross-entropy that ignores examples with missing labels.
  9.  
  10. Parameters
  11. ----------
  12. y_true : tensor-like
  13. Integer tensor of binary class labels. Missing entries should have the
  14. value of `missing_indicator`.
  15.  
  16. y_pred : tensor-like
  17. Float tensor of logits or probabilities for the positive class,
  18. depending on `from_logits`.
  19.  
  20. missing_indicator : integer-like
  21. The value that indicates a missing label in `y_true`.
  22.  
  23. from_logits : bool
  24. Whether `y_pred` contains logits (if True) or probabilities (if False)
  25.  
  26. Returns
  27. -------
  28. tf.Tensor
  29. Scalar tensor of the cross-entropy loss averaged over all non-missing
  30. examples.
  31. """
  32. # Process inputs
  33. y_pred = tf.convert_to_tensor(y_pred)
  34. if not y_pred.dtype.is_floating:
  35. y_pred = tf.dtypes.cast(y_pred, dtype=tf.dtypes.float32)
  36. y_true = tf.convert_to_tensor(y_true, dtype=y_pred.dtype)
  37.  
  38. # True if label is present, False if label is missing
  39. mask = tf.math.not_equal(y_true, missing_indicator)
  40.  
  41. def true_fn():
  42. """Called when at least one example in the batch has a label."""
  43.  
  44. def original_args():
  45. """Called when every example has a label."""
  46. return y_true, y_pred
  47.  
  48. def masked_args():
  49. """Called when at least one example is missing a label."""
  50. return tf.boolean_mask(y_true, mask), tf.boolean_mask(y_pred, mask)
  51.  
  52. pred = tf.reduce_all(mask)
  53. target, output = tf.cond(pred, original_args, masked_args)
  54. loss = tf.keras.backend.binary_crossentropy(target=target,
  55. output=output,
  56. from_logits=from_logits)
  57. return tf.math.reduce_mean(loss)
  58.  
  59. def false_fn():
  60. """Called when no example in the batch has a label."""
  61. return tf.convert_to_tensor(0, dtype=y_pred.dtype)
  62.  
  63. return tf.cond(tf.reduce_any(mask), true_fn, false_fn)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement