SHARE
TWEET

Untitled

a guest Oct 21st, 2019 77 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. """Loss functions for semi-supervised learning."""
  2.  
  3. import tensorflow as tf
  4.  
  5.  
  6. def ssl_binary_crossentropy(y_true, y_pred, missing_indicator=-1,
  7.                             from_logits=False):
  8.     """Binary cross-entropy that ignores examples with missing labels.
  9.  
  10.     Parameters
  11.     ----------
  12.     y_true : tensor-like
  13.         Integer tensor of binary class labels. Missing entries should have the
  14.         value of `missing_indicator`.
  15.  
  16.     y_pred : tensor-like
  17.         Float tensor of logits or probabilities for the positive class,
  18.         depending on `from_logits`.
  19.  
  20.     missing_indicator : integer-like
  21.         The value that indicates a missing label in `y_true`.
  22.  
  23.     from_logits : bool
  24.         Whether `y_pred` contains logits (if True) or probabilities (if False)
  25.  
  26.     Returns
  27.     -------
  28.     tf.Tensor
  29.         Scalar tensor of the cross-entropy loss averaged over all non-missing
  30.         examples.
  31.     """
  32.     # Process inputs
  33.     y_pred = tf.convert_to_tensor(y_pred)
  34.     if not y_pred.dtype.is_floating:
  35.         y_pred = tf.dtypes.cast(y_pred, dtype=tf.dtypes.float32)
  36.     y_true = tf.convert_to_tensor(y_true, dtype=y_pred.dtype)
  37.  
  38.     # True if label is present, False if label is missing
  39.     mask = tf.math.not_equal(y_true, missing_indicator)
  40.  
  41.     def true_fn():
  42.         """Called when at least one example in the batch has a label."""
  43.  
  44.         def original_args():
  45.             """Called when every example has a label."""
  46.             return y_true, y_pred
  47.  
  48.         def masked_args():
  49.             """Called when at least one example is missing a label."""
  50.             return tf.boolean_mask(y_true, mask), tf.boolean_mask(y_pred, mask)
  51.  
  52.         pred = tf.reduce_all(mask)
  53.         target, output = tf.cond(pred, original_args, masked_args)
  54.         loss = tf.keras.backend.binary_crossentropy(target=target,
  55.                                                     output=output,
  56.                                                     from_logits=from_logits)
  57.         return tf.math.reduce_mean(loss)
  58.  
  59.     def false_fn():
  60.         """Called when no example in the batch has a label."""
  61.         return tf.convert_to_tensor(0, dtype=y_pred.dtype)
  62.  
  63.     return tf.cond(tf.reduce_any(mask), true_fn, false_fn)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top