easy_rec/python/loss/listwise_loss.py:

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.utils.load_class import load_by_path


def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
  # Softmax cross entropy over the samples belonging to session `x`.
  mask = tf.equal(x, session_ids)
  logits = tf.boolean_mask(logits, mask)
  labels = tf.boolean_mask(labels, mask)
  y = tf.nn.softmax(labels) if label_is_logits else labels
  y_hat = tf.nn.log_softmax(logits)
  return -tf.reduce_sum(y * y_hat)


def _list_prob_loss(x, labels, logits, session_ids):
  # Cross entropy between the labels (normalized to a probability
  # distribution) and the softmax of logits within session `x`.
  mask = tf.equal(x, session_ids)
  logits = tf.boolean_mask(logits, mask)
  labels = tf.boolean_mask(labels, mask)
  y = labels / tf.reduce_sum(labels)
  y_hat = tf.nn.log_softmax(logits)
  return -tf.reduce_sum(y * y_hat)


def listwise_rank_loss(labels,
                       logits,
                       session_ids,
                       transform_fn=None,
                       temperature=1.0,
                       label_is_logits=False,
                       scale_logits=False,
                       weights=1.0,
                       name='listwise_loss'):
  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.

  Definition:

  $$
  \mathcal{L}(\{y\}, \{s\}) = -\sum_i y_i \log( \frac{\exp(s_i)}{\sum_j \exp(s_j)} )
  $$

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, used to compute the GAUC metric, e.g. user_id.
    transform_fn: an affine transformation function of labels.
    temperature: (Optional) The temperature to use for scaling the logits.
    label_is_logits: Whether `labels` is expected to be a logits tensor.
      By default, we consider that `labels` encodes a probability distribution.
    scale_logits: Whether to scale the logits.
    weights: sample weights.
    name: the name of loss.
  """
  loss_name = name if name else 'listwise_rank_loss'
  logging.info('[{}] temperature: {}, scale logits: {}'.format(
      loss_name, temperature, scale_logits))
  labels = tf.to_float(labels)
  if scale_logits:
    with tf.variable_scope(loss_name):
      w = tf.get_variable(
          'scale_w',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.ones_initializer())
      b = tf.get_variable(
          'scale_b',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.zeros_initializer())
    logits = logits * tf.abs(w) + b
  if temperature != 1.0:
    logits /= temperature
    if label_is_logits:
      labels /= temperature
  if transform_fn is not None:
    trans_fn = load_by_path(transform_fn)
    labels = trans_fn(labels)

  sessions, _ = tf.unique(tf.squeeze(session_ids))
  tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
  losses = tf.map_fn(
      lambda x: _list_wise_loss(x, labels, logits, session_ids,
                                label_is_logits),
      sessions,
      dtype=tf.float32)
  if tf.is_numeric_tensor(weights):
    logging.error('[%s] use unsupported sample weight' % loss_name)
    return tf.reduce_mean(losses)
  else:
    return tf.reduce_mean(losses) * weights
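
# Usage sketch for listwise_rank_loss: a minimal, hypothetical TF1-style
# example of calling the loss directly; the tensor values below are made-up
# placeholders (in EasyRec the loss is usually selected through the loss
# config rather than called by hand). With label_is_logits=True the graded
# relevance scores are softmax-normalized per session before the cross
# entropy is computed.
#
#   import tensorflow as tf
#   from easy_rec.python.loss.listwise_loss import listwise_rank_loss
#
#   # three samples of session 1 and two samples of session 2
#   labels = tf.constant([2.0, 0.0, 1.0, 3.0, 0.0])   # graded relevance
#   logits = tf.constant([0.8, -0.5, 0.2, 1.2, 0.1])  # model scores
#   session_ids = tf.constant([1, 1, 1, 2, 2])        # e.g. user ids
#   loss = listwise_rank_loss(
#       labels, logits, session_ids, temperature=2.0, label_is_logits=True)
#   with tf.Session() as sess:
#     print(sess.run(loss))  # mean of per-session softmax cross entropies
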
def listwise_distill_loss(labels,
                          logits,
                          session_ids,
                          transform_fn=None,
                          temperature=1.0,
                          label_clip_max_value=512,
                          scale_logits=False,
                          weights=1.0,
                          name='listwise_distill_loss'):
  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.

  Definition:

  $$
  \mathcal{L}(\{y\}, \{s\}) = -\sum_i y_i \log( \frac{\exp(s_i)}{\sum_j \exp(s_j)} )
  $$

  Args:
    labels: A `Tensor` of the same shape as `logits` representing the rank
      positions of a base (teacher) model.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, used to compute the GAUC metric, e.g. user_id.
    transform_fn: a transformation function of labels.
    temperature: (Optional) The temperature to use for scaling the logits.
    label_clip_max_value: clip the labels to this value.
    scale_logits: Whether to scale the logits.
    weights: sample weights.
    name: the name of loss.
  """
  loss_name = name if name else 'listwise_distill_loss'
  logging.info('[{}] temperature: {}'.format(loss_name, temperature))
  labels = tf.to_float(labels)  # supposed to be positions of a teacher model
  labels = tf.clip_by_value(labels, 1, label_clip_max_value)
  if transform_fn is not None:
    trans_fn = load_by_path(transform_fn)
    labels = trans_fn(labels)
  else:
    labels = tf.log1p(label_clip_max_value) - tf.log(labels)

  if scale_logits:
    with tf.variable_scope(loss_name):
      w = tf.get_variable(
          'scale_w',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.ones_initializer())
      b = tf.get_variable(
          'scale_b',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.zeros_initializer())
    logits = logits * tf.abs(w) + b
  if temperature != 1.0:
    logits /= temperature

  sessions, _ = tf.unique(tf.squeeze(session_ids))
  tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
  losses = tf.map_fn(
      lambda x: _list_prob_loss(x, labels, logits, session_ids),
      sessions,
      dtype=tf.float32)
  if tf.is_numeric_tensor(weights):
    logging.error('[%s] use unsupported sample weight' % loss_name)
    return tf.reduce_mean(losses)
  else:
    return tf.reduce_mean(losses) * weights
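
# Usage sketch for listwise_distill_loss: a minimal, hypothetical TF1-style
# example with placeholder values. Here `labels` are the rank positions
# produced by a teacher model (1 = best); with no transform_fn they are mapped
# to log1p(label_clip_max_value) - log(position), so better (smaller)
# positions receive larger soft targets, which are then normalized per
# session inside _list_prob_loss.
#
#   teacher_rank = tf.constant([1.0, 3.0, 2.0, 1.0, 2.0])  # teacher positions
#   logits = tf.constant([1.1, -0.3, 0.4, 0.9, 0.2])       # student scores
#   session_ids = tf.constant([1, 1, 1, 2, 2])
#   loss = listwise_distill_loss(
#       teacher_rank,
#       logits,
#       session_ids,
#       label_clip_max_value=16.0)  # a float value keeps tf.log1p happy
#   with tf.Session() as sess:
#     print(sess.run(loss))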