easy_rec/python/loss/zero_inflated_lognormal.py (38 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Zero-inflated lognormal loss for lifetime value prediction."""
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Compare the major version numerically. A plain string comparison is
# lexicographic, so e.g. '10.0' >= '2.0' evaluates to False and the
# compat.v1 switch would be skipped on such a version.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1


def zero_inflated_lognormal_pred(logits):
  """Calculates predicted mean of zero inflated lognormal logits.

  The last axis of `logits` is split as follows: column 0 is the logit of
  the positive probability, column 1 is the lognormal `loc`, and the
  remaining column(s) are mapped through softplus to obtain `scale`.

  Arguments:
    logits: [batch_size, 3] tensor of logits.

  Returns:
    positive_probs: [batch_size, 1] tensor of positive probability.
    preds: [batch_size, 1] tensor of predicted mean.
  """
  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  positive_probs = tf.keras.backend.sigmoid(logits[..., :1])
  loc = logits[..., 1:2]
  scale = tf.keras.backend.softplus(logits[..., 2:])
  # Mean of a lognormal is exp(loc + scale^2 / 2); weighting it by the
  # positive probability gives the mean of the zero-inflated mixture.
  preds = (
      positive_probs *
      tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
  return positive_probs, preds


def zero_inflated_lognormal_loss(labels, logits, name=''):
  """Computes the zero inflated lognormal loss.

  The loss is the sum of two terms:
    * a binary cross-entropy between `labels > 0` and the first logit;
    * the negative lognormal log-likelihood of the labels, masked so that
      only positive labels contribute, averaged over all entries.

  Both terms are also written as scalar summaries named
  `loss/<name>_classify` and `loss/<name>_regression`.

  Usage with tf.keras API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=zero_inflated_lognormal_loss)
  ```

  Arguments:
    labels: True targets, tensor of shape [batch_size, 1]. A rank-1 tensor
      of shape [batch_size] is also accepted and is expanded to
      [batch_size, 1].
    logits: Logits of output layer, tensor of shape [batch_size, 3].
    name: the name of loss, used in the summary tags; falls back to
      'ziln_loss' when empty.

  Returns:
    Zero inflated lognormal loss value.

  Raises:
    ValueError: if the static shape of `logits` is not compatible with
      the leading dimensions of `labels` followed by a last dimension
      of 3.
  """
  loss_name = name if name else 'ziln_loss'

  labels = tf.cast(labels, dtype=tf.float32)
  if labels.shape.ndims == 1:
    labels = tf.expand_dims(labels, 1)  # [B, 1]
  # 1.0 where the label is strictly positive, 0.0 otherwise.
  positive = tf.cast(labels > 0, tf.float32)

  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  logits.shape.assert_is_compatible_with(
      tf.TensorShape(labels.shape[:-1].as_list() + [3]))

  # Classification part: is the label positive or not?
  positive_logits = logits[..., :1]
  classification_loss = tf.keras.backend.binary_crossentropy(
      positive, positive_logits, from_logits=True)
  classification_loss = tf.keras.backend.mean(classification_loss)
  tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)

  # Regression part: lognormal likelihood of the positive labels.
  loc = logits[..., 1:2]
  # Bound the scale from below by sqrt(epsilon) so that it never reaches
  # zero.
  scale = tf.math.maximum(
      tf.keras.backend.softplus(logits[..., 2:]),
      tf.math.sqrt(tf.keras.backend.epsilon()))
  # Substitute 1 for every non-positive label so that log_prob is only
  # evaluated at positive values; those entries are multiplied by zero
  # through the `positive` mask below and add nothing to the loss.
  safe_labels = positive * labels + (
      1 - positive) * tf.keras.backend.ones_like(labels)
  regression_loss = -tf.keras.backend.mean(
      positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels))
  tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)

  return classification_loss + regression_loss