Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """Loss functions for semi-supervised learning."""
- import tensorflow as tf
def ssl_binary_crossentropy(y_true, y_pred, missing_indicator=-1,
                            from_logits=False):
    """Binary cross-entropy that ignores examples with missing labels.

    Parameters
    ----------
    y_true : tensor-like
        Integer tensor of binary class labels. Missing entries should have the
        value of `missing_indicator`. It is cast to the dtype of `y_pred`
        before use, so integer and floating label tensors are both accepted.
    y_pred : tensor-like
        Float tensor of logits or probabilities for the positive class,
        depending on `from_logits`. Non-floating inputs are cast to float32.
        The shape of `y_true` must match the leading dimensions of `y_pred`,
        because the label mask is applied to both with `tf.boolean_mask`.
    missing_indicator : integer-like
        The value that indicates a missing label in `y_true`.
    from_logits : bool
        Whether `y_pred` contains logits (if True) or probabilities (if False).

    Returns
    -------
    tf.Tensor
        Scalar tensor of the cross-entropy loss averaged over all non-missing
        examples. If every label is missing (including an empty batch), the
        result is zero.
    """
    # Process inputs
    y_pred = tf.convert_to_tensor(y_pred)
    if not y_pred.dtype.is_floating:
        y_pred = tf.dtypes.cast(y_pred, dtype=tf.dtypes.float32)
    # Convert first, then cast: `tf.convert_to_tensor(x, dtype=...)` refuses to
    # change the dtype of an existing tensor, so an integer label tensor (the
    # documented input) would otherwise raise a ValueError.
    y_true = tf.dtypes.cast(tf.convert_to_tensor(y_true), dtype=y_pred.dtype)

    # True if label is present, False if label is missing
    mask = tf.math.not_equal(y_true, missing_indicator)

    # Keep only labelled examples. With an all-True mask this yields the same
    # elements as the unmasked inputs, so no separate "every example has a
    # label" branch is needed.
    target = tf.boolean_mask(y_true, mask)
    output = tf.boolean_mask(y_pred, mask)
    loss = tf.keras.backend.binary_crossentropy(target=target,
                                                output=output,
                                                from_logits=from_logits)

    # Mean over the element-wise losses. divide_no_nan returns 0 instead of
    # NaN when every label is missing (the empty sum is 0 and the count is 0),
    # replacing the former tf.cond guard.
    n_elements = tf.dtypes.cast(tf.size(loss), dtype=loss.dtype)
    return tf.math.divide_no_nan(tf.math.reduce_sum(loss), n_elements)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement