easy_rec/python/loss/pairwise_loss.py (214 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import tensorflow as tf
from tensorflow.python.ops.losses.losses_impl import compute_weighted_loss
from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
from easy_rec.python.utils.shape_utils import get_shape_list
# Alias `tf` to the TF1-compatible API when running under TensorFlow 2.x or
# later. The major version is compared numerically: a plain string comparison
# such as `tf.__version__ >= '2.0'` is lexicographic and would misclassify a
# version like '10.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
def pairwise_loss(labels,
                  logits,
                  session_ids=None,
                  margin=0,
                  temperature=1.0,
                  weights=1.0,
                  name=''):
  """Deprecated pairwise loss; see `pairwise_logistic_loss` for the successor.

  Each ordered pair (i, j) with `labels[i] > labels[j]` (and, when
  `session_ids` is given, equal session ids) contributes a sigmoid cross
  entropy term computed on `logits[i] - logits[j] - margin` against a
  pseudo label of 1.

  Args:
    labels: a `Tensor` with shape [batch_size]. e.g. click or not click in
      the session.
    logits: a `Tensor` with shape [batch_size]. e.g. the value of last neuron
      before activation.
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, used to max GAUC metric. e.g. user_id
    margin: the margin between positive and negative sample pair
    temperature: (Optional) The temperature to use for scaling the logits.
    weights: sample weights; a numeric tensor is expanded so that pair (i, j)
      uses the weight of sample i, anything else is passed through unchanged.
    name: the name of loss

  Returns:
    The loss tensor produced by `tf.losses.sigmoid_cross_entropy` over the
    selected pairs.
  """
  logging.warning(
      'The old `pairwise_loss` is being deprecated. '
      'Please use the new `pairwise_logistic_loss` or `pairwise_focal_loss`')
  loss_name = name or 'pairwise_loss'
  logging.info('[{}] margin: {}, temperature: {}'.format(
      loss_name, margin, temperature))

  if temperature != 1.0:
    logits /= temperature

  # Entry [i, j] of the difference matrix is logits[i] - logits[j] - margin.
  logit_col = tf.expand_dims(logits, -1)
  logit_row = tf.expand_dims(logits, 0)
  pair_logits = tf.math.subtract(logit_col, logit_row) - margin

  # Only pairs whose first element has the strictly larger label are kept.
  label_col = tf.expand_dims(labels, -1)
  label_row = tf.expand_dims(labels, 0)
  pair_mask = tf.greater(label_col, label_row)

  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pair_mask = tf.logical_and(pair_mask, same_session)

  pair_logits = tf.boolean_mask(pair_logits, pair_mask)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(pair_logits))

  pair_weights = weights
  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Repeat each sample weight along its row so that pair (i, j) is weighted
    # by sample i, then keep the same entries as the logits.
    weight_col = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    num_rows, _ = get_shape_list(weight_col, 2)
    tiled_weights = tf.tile(weight_col, tf.stack([1, num_rows]))
    pair_weights = tf.boolean_mask(tiled_weights, pair_mask)

  # Every retained pair should rank i above j, hence the all-ones target.
  pseudo_labels = tf.ones_like(pair_logits)
  # NOTE: no explicit NaN guard is applied for batches without a valid pair.
  return tf.losses.sigmoid_cross_entropy(
      pseudo_labels, pair_logits, weights=pair_weights)
def pairwise_focal_loss(labels,
                        logits,
                        session_ids=None,
                        hinge_margin=None,
                        gamma=2,
                        alpha=None,
                        ohem_ratio=1.0,
                        temperature=1.0,
                        weights=1.0,
                        name=''):
  """Pairwise loss that scores the selected pairs with a sigmoid focal loss.

  Pairs (i, j) with `labels[i] > labels[j]` are selected, optionally limited
  to pairs whose logit difference is below `hinge_margin` and to pairs that
  share a session id. The logit differences of the selected pairs are passed
  to `sigmoid_focal_loss_with_logits` with a pseudo label of 1.

  Args:
    labels: a `Tensor` with shape [batch_size].
    logits: a `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]; when given, only pairs
      with equal session ids are used.
    hinge_margin: when not None, only pairs with
      `logits[i] - logits[j] < hinge_margin` are kept.
    gamma: forwarded to `sigmoid_focal_loss_with_logits`.
    alpha: forwarded to `sigmoid_focal_loss_with_logits`.
    ohem_ratio: must be in (0, 1]; forwarded to
      `sigmoid_focal_loss_with_logits`.
    temperature: (Optional) The temperature to use for scaling the logits.
    weights: sample weights; a numeric tensor is expanded so that pair (i, j)
      uses the weight of sample i, anything else is passed through unchanged.
    name: the name of loss

  Returns:
    The value returned by `sigmoid_focal_loss_with_logits`.
  """
  loss_name = name or 'pairwise_focal_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info(
      '[{}] hinge margin: {}, gamma: {}, alpha: {}, ohem_ratio: {}, temperature: {}'
      .format(loss_name, hinge_margin, gamma, alpha, ohem_ratio, temperature))

  if temperature != 1.0:
    logits /= temperature

  # Entry [i, j] of the difference matrix is logits[i] - logits[j].
  logit_col = tf.expand_dims(logits, -1)
  logit_row = tf.expand_dims(logits, 0)
  pair_logits = logit_col - logit_row

  label_col = tf.expand_dims(labels, -1)
  label_row = tf.expand_dims(labels, 0)
  pair_mask = tf.greater(label_col, label_row)

  if hinge_margin is not None:
    # Drop pairs that are already separated by at least the margin.
    below_margin = tf.less(pair_logits, hinge_margin)
    pair_mask = tf.logical_and(pair_mask, below_margin)

  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pair_mask = tf.logical_and(pair_mask, same_session)

  pair_logits = tf.boolean_mask(pair_logits, pair_mask)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(pair_logits))

  pair_weights = weights
  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Repeat each sample weight along its row so that pair (i, j) is weighted
    # by sample i, then keep the same entries as the logits.
    weight_col = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    num_rows, _ = get_shape_list(weight_col, 2)
    tiled_weights = tf.tile(weight_col, tf.stack([1, num_rows]))
    pair_weights = tf.boolean_mask(tiled_weights, pair_mask)

  # Every retained pair should rank i above j, hence the all-ones target.
  pseudo_labels = tf.ones_like(pair_logits)
  return sigmoid_focal_loss_with_logits(
      pseudo_labels,
      pair_logits,
      gamma=gamma,
      alpha=alpha,
      ohem_ratio=ohem_ratio,
      sample_weights=pair_weights)
def pairwise_logistic_loss(labels,
                           logits,
                           session_ids=None,
                           temperature=1.0,
                           hinge_margin=None,
                           weights=1.0,
                           ohem_ratio=1.0,
                           use_label_margin=False,
                           name=''):
  r"""Computes pairwise logistic loss between `labels` and `logits`, equivalent to RankNet loss.

  Definition:

  $$
  \mathcal{L}(\{y\}, \{s\}) =
  \sum_i \sum_j I[y_i > y_j] \log(1 + \exp(-(s_i - s_j)))
  $$

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, used to max GAUC metric. e.g. user_id
    temperature: (Optional) The temperature to use for scaling the logits.
      When `use_label_margin` is set, the labels are scaled as well.
    hinge_margin: the margin between positive and negative logits; ignored
      when `use_label_margin` is True.
    weights: A scalar, a `Tensor` with shape [batch_size] for each sample
    ohem_ratio: the percent of hard examples to be mined, must be in (0, 1]
    use_label_margin: whether to use the diff `label[i]-label[j]` as margin
    name: the name of loss

  Returns:
    The weighted loss from `compute_weighted_loss` when `ohem_ratio` is 1.0;
    otherwise the mean of the positive losses among the top
    `round(ohem_ratio * num_pairs)` weighted pair losses, or 0 when no such
    loss exists.
  """
  loss_name = name if name else 'pairwise_logistic_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info('[{}] hinge margin: {}, ohem_ratio: {}, temperature: {}'.format(
      loss_name, hinge_margin, ohem_ratio, temperature))

  if temperature != 1.0:
    logits /= temperature
    if use_label_margin:
      # Labels act as a margin on the logits, so keep them on the same scale.
      labels /= temperature

  # Entry [i, j] of the difference matrix is logits[i] - logits[j].
  pairwise_logits = tf.math.subtract(
      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
  if use_label_margin:
    pairwise_logits -= tf.math.subtract(
        tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  elif hinge_margin is not None:
    pairwise_logits -= hinge_margin

  # Only pairs whose first element has the strictly larger label are kept.
  pairwise_mask = tf.greater(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    group_equal = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pairwise_mask = tf.logical_and(pairwise_mask, group_equal)

  pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
  num_pair = tf.size(pairwise_logits)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)

  # The following is the same as log(1 + exp(-pairwise_logits)).
  losses = tf.nn.relu(-pairwise_logits) + tf.math.log1p(
      tf.exp(-tf.abs(pairwise_logits)))

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Repeat each sample weight along its row so that pair (i, j) is weighted
    # by sample i, then keep the same entries as the logits.
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
    pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
  else:
    pairwise_weights = weights

  if ohem_ratio == 1.0:
    return compute_weighted_loss(losses, pairwise_weights)

  # Online hard example mining: keep only the largest weighted pair losses.
  losses = compute_weighted_loss(
      losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
  k = tf.to_int32(tf.math.rint(k))
  topk = tf.nn.top_k(losses, k)
  hard_losses = tf.boolean_mask(topk.values, topk.values > 0)
  # A plain reduce_mean over an empty selection (no valid pair in the batch,
  # or every weighted loss is zero) evaluates to NaN; dividing the sum by
  # max(count, 1) yields the same mean otherwise and 0 for the empty case.
  num_hard = tf.maximum(tf.cast(tf.size(hard_losses), hard_losses.dtype), 1.0)
  return tf.reduce_sum(hard_losses) / num_hard
def pairwise_hinge_loss(labels,
                        logits,
                        session_ids=None,
                        temperature=1.0,
                        margin=1.0,
                        weights=1.0,
                        ohem_ratio=1.0,
                        label_is_logits=True,
                        use_label_margin=True,
                        use_exponent=False,
                        name=''):
  r"""Computes pairwise hinge loss between `labels` and `logits`.

  Definition:

  $$
  \mathcal{L}(\{y\}, \{s\}) =
  \sum_i \sum_j I[y_i > y_j] \max(0, 1 - (s_i - s_j))
  $$

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, used to max GAUC metric. e.g. user_id
    temperature: (Optional) The temperature to use for scaling the logits.
    margin: the margin between positive and negative logits; only used when
      `use_label_margin` is False.
    weights: A scalar, a `Tensor` with shape [batch_size] for each sample
    ohem_ratio: the percent of hard examples to be mined, must be in (0, 1]
    label_is_logits: Whether `labels` is expected to be a logits tensor; if
      so the labels are scaled by `temperature` together with the logits.
    use_label_margin: whether to use the diff `label[i]-label[j]` as margin
    use_exponent: whether to use exponential difference; both labels and
      logits are passed through a sigmoid first and the per-pair loss becomes
      `max(0, exp(diff) - 1)`.
    name: the name of loss

  Returns:
    The weighted loss from `compute_weighted_loss` when `ohem_ratio` is 1.0;
    otherwise the mean of the positive losses among the top
    `round(ohem_ratio * num_pairs)` weighted pair losses, or 0 when no such
    loss exists.
  """
  loss_name = name if name else 'pairwise_hinge_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info(
      '[{}] margin: {}, ohem_ratio: {}, temperature: {}, use_exponent: {}, label_is_logits: {}, use_label_margin: {}'
      .format(loss_name, margin, ohem_ratio, temperature, use_exponent,
              label_is_logits, use_label_margin))

  if temperature != 1.0:
    logits /= temperature
    if label_is_logits:
      labels /= temperature
  if use_exponent:
    labels = tf.nn.sigmoid(labels)
    # Squash the model outputs themselves. Feeding `labels` here would make
    # the loss independent of `logits`, so nothing could be learned.
    logits = tf.nn.sigmoid(logits)

  # Entry [i, j] of each matrix is value[i] - value[j].
  pairwise_logits = tf.math.subtract(
      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
  pairwise_labels = tf.math.subtract(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))

  # Only pairs whose first element has the strictly larger label are kept.
  pairwise_mask = tf.greater(pairwise_labels, 0)
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    group_equal = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pairwise_mask = tf.logical_and(pairwise_mask, group_equal)

  pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
  pairwise_labels = tf.boolean_mask(pairwise_labels, pairwise_mask)
  num_pair = tf.size(pairwise_logits)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)

  # The required separation is either the label gap or the fixed margin.
  if use_label_margin:
    diff = pairwise_labels - pairwise_logits
  else:
    diff = margin - pairwise_logits

  if use_exponent:
    # exp(88) is still below the float32 maximum (3.4028235e+38), so clipping
    # the exponent keeps tf.exp from overflowing.
    threshold = 88.0
    safe_diff = tf.clip_by_value(diff, -threshold, threshold)
    losses = tf.nn.relu(tf.exp(safe_diff) - 1.0)
  else:
    losses = tf.nn.relu(diff)

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Repeat each sample weight along its row so that pair (i, j) is weighted
    # by sample i, then keep the same entries as the logits.
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
    pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
  else:
    pairwise_weights = weights

  if ohem_ratio == 1.0:
    return compute_weighted_loss(losses, pairwise_weights)

  # Online hard example mining: keep only the largest weighted pair losses.
  losses = compute_weighted_loss(
      losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
  k = tf.to_int32(tf.math.rint(k))
  topk = tf.nn.top_k(losses, k)
  hard_losses = tf.boolean_mask(topk.values, topk.values > 0)
  # Once every pair satisfies its margin no positive loss is left, and a
  # plain reduce_mean over the empty selection evaluates to NaN; dividing
  # the sum by max(count, 1) yields the same mean otherwise and 0 when empty.
  num_hard = tf.maximum(tf.cast(tf.size(hard_losses), hard_losses.dtype), 1.0)
  return tf.reduce_sum(hard_losses) / num_hard