def jrc_loss()

in easy_rec/python/loss/jrc_loss.py


import logging

import numpy as np
import tensorflow as tf

# TF1-style API (tf.losses, tf.to_float, tf.random_uniform, ...) is used below;
# under TF2 the module is expected to alias the compat.v1 API.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def jrc_loss(labels,
             logits,
             session_ids,
             alpha=0.5,
             loss_weight_strategy='fixed',
             sample_weights=1.0,
             same_label_loss=True,
             name=''):
  """Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model.

     https://arxiv.org/abs/2208.06164

  Args:
    labels: a `Tensor` with shape [batch_size], e.g. whether each sample in the session
      was clicked (1) or not (0).
    logits: a `Tensor` with shape [batch_size, 2], e.g. the outputs of the last layer
      before activation.
    session_ids: a `Tensor` with shape [batch_size]. Session id of each sample, e.g. user_id;
      samples sharing a session id form one context (the grouping used by the GAUC metric).
    alpha: the weight to balance the ranking loss and the calibration loss.
    loss_weight_strategy: str, the strategy for balancing ce_loss and ge_loss; one of
      'fixed', 'random_uniform', 'random_normal', 'random_bernoulli', 'uncertainty'.
    sample_weights: coefficients for the loss. This must be a scalar or broadcastable to
      `labels` (i.e. same rank and each dimension is either 1 or the same).
    same_label_loss: whether to enable ge_loss for samples with the same label in a session.
    name: the name of the loss.

  Returns:
    a scalar loss `Tensor` that combines the ranking (ce) and calibration (ge) losses.

  Raises:
    ValueError: if `loss_weight_strategy` is not one of the supported strategies.
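
  Example:
    A minimal illustrative sketch; the tensor values below are hypothetical placeholders:

      labels = tf.constant([1, 0, 1, 0])            # [batch_size]
      logits = tf.random_normal([4, 2])             # [batch_size, 2], pre-activation
      session_ids = tf.constant([10, 10, 20, 20])   # [batch_size], two sessions
      loss = jrc_loss(labels, logits, session_ids, alpha=0.7)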
  """
  loss_name = name if name else 'jrc_loss'
  logging.info('[{}] alpha: {}, loss_weight_strategy: {}'.format(
      loss_name, alpha, loss_weight_strategy))

  ce_loss = tf.losses.sparse_softmax_cross_entropy(
      labels, logits, weights=sample_weights)

  labels = tf.expand_dims(labels, 1)  # [B, 1]
  labels = tf.concat([1 - labels, labels], axis=1)  # [B, 2]

  batch_size = tf.shape(logits)[0]

  # Mask: shape [B, B], mask[i,j]=1 indicates the i-th sample
  # and j-th sample are in the same context
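  # e.g. session_ids = [1, 1, 2] yields mask = [[1, 1, 0], [1, 1, 0], [0, 0, 1]]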
  mask = tf.equal(
      tf.expand_dims(session_ids, 1), tf.expand_dims(session_ids, 0))
  mask = tf.to_float(mask)

  # Tile logits and label: [B, 2]->[B, B, 2]
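  # After tiling, logits[i, j] holds the logits of sample i and y[i, j] holds the
  # one-hot label of sample i, for every column j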
  logits = tf.tile(tf.expand_dims(logits, 1), [1, batch_size, 1])
  y = tf.tile(tf.expand_dims(labels, 1), [1, batch_size, 1])

  # Set logits that are not in the same context to -inf
  mask3d = tf.expand_dims(mask, 2)
  y = tf.to_float(y) * mask3d
  logits = logits + (1 - mask3d) * -1e9
  y_neg, y_pos = y[:, :, 0], y[:, :, 1]
  l_neg, l_pos = logits[:, :, 0], logits[:, :, 1]

  if tf.is_numeric_tensor(sample_weights):
    logging.info('[%s] use sample weight' % loss_name)
    weights = tf.expand_dims(tf.cast(sample_weights, tf.float32), 0)
    pairwise_weights = tf.tile(weights, tf.stack([batch_size, 1]))
    y_pos *= pairwise_weights
    y_neg *= pairwise_weights

  # Compute list-wise generative loss -log p(x|y, z)
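  # log_softmax over axis 0 normalizes each column over the samples that share that
  # column's context; cross-context entries get ~0 probability due to the -1e9 offset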
  if same_label_loss:
    logging.info('[%s] enable same_label_loss' % loss_name)
    loss_pos = -tf.reduce_sum(y_pos * tf.nn.log_softmax(l_pos, axis=0), axis=0)
    loss_neg = -tf.reduce_sum(y_neg * tf.nn.log_softmax(l_neg, axis=0), axis=0)
    ge_loss = tf.reduce_mean(
        (loss_pos + loss_neg) / tf.reduce_sum(mask, axis=0))
  else:
    logging.info('[%s] disable same_label_loss' % loss_name)
    diag = tf.one_hot(tf.range(batch_size), batch_size)
    l_pos = l_pos + (1 - diag) * y_pos * -1e9
    l_neg = l_neg + (1 - diag) * y_neg * -1e9
    loss_pos = -tf.linalg.diag_part(y_pos * tf.nn.log_softmax(l_pos, axis=0))
    loss_neg = -tf.linalg.diag_part(y_neg * tf.nn.log_softmax(l_neg, axis=0))
    ge_loss = tf.reduce_mean(loss_pos + loss_neg)

  tf.summary.scalar('loss/%s_ce' % loss_name, ce_loss)
  tf.summary.scalar('loss/%s_ge' % loss_name, ge_loss)

  # The final JRC loss: combine ce_loss and ge_loss according to loss_weight_strategy
  if loss_weight_strategy == 'fixed':
    loss = alpha * ce_loss + (1 - alpha) * ge_loss
  elif loss_weight_strategy == 'random_uniform':
    weight = tf.random_uniform([])
    loss = weight * ce_loss + (1 - weight) * ge_loss
    tf.summary.scalar('loss/%s_ce_weight' % loss_name, weight)
    tf.summary.scalar('loss/%s_ge_weight' % loss_name, 1 - weight)
  elif loss_weight_strategy == 'random_normal':
    weights = tf.random_normal([2])
    loss_weight = tf.nn.softmax(weights)
    loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
    tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
    tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
  elif loss_weight_strategy == 'random_bernoulli':
    bern = tf.distributions.Bernoulli(probs=0.5, dtype=tf.float32)
    weights = bern.sample(2)
    loss_weight = tf.cond(
        tf.equal(tf.reduce_sum(weights), 1), lambda: weights,
        lambda: tf.convert_to_tensor([0.5, 0.5]))
    loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
    tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
    tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
  elif loss_weight_strategy == 'uncertainty':
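    # Uncertainty-style weighting: scale each loss by exp(-s) and add 0.5 * s,
    # where s is a learned scalar (log-variance-like) per loss term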
    uncertainty1 = tf.Variable(
        0, name='%s_ranking_loss_weight' % loss_name, dtype=tf.float32)
    tf.summary.scalar('loss/%s_ranking_uncertainty' % loss_name, uncertainty1)
    uncertainty2 = tf.Variable(
        0, name='%s_calibration_loss_weight' % loss_name, dtype=tf.float32)
    tf.summary.scalar('loss/%s_calibration_uncertainty' % loss_name,
                      uncertainty2)
    loss = tf.exp(-uncertainty1) * ce_loss + 0.5 * uncertainty1
    loss += tf.exp(-uncertainty2) * ge_loss + 0.5 * uncertainty2
  else:
    raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
                     loss_weight_strategy)
  if np.isscalar(sample_weights) and sample_weights != 1.0:
    return loss * sample_weights
  return loss