tzrec/loss/jrc_loss.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.nn.modules.loss import _Loss


@torch.fx.wrap
def _label_mask(labels: torch.Tensor) -> torch.Tensor:
    # Identity matrix marking each sample's own (diagonal) position.
    return torch.eye(labels.size(0), dtype=torch.int64, device=labels.device)


@torch.fx.wrap
def _diag_index(labels: torch.Tensor) -> torch.Tensor:
    # Row indices [0, batch_size), used as CE targets for the diagonal.
    return torch.arange(0, labels.size(0), dtype=torch.int64, device=labels.device)


class JRCLoss(_Loss):
    """JRC loss: each positive sample's probability competes within its session.

    See "Joint Optimization of Ranking and Calibration with Contextualized
    Hybrid Model", https://arxiv.org/abs/2208.06164.

    Args:
        alpha (float): weight of the pointwise cross-entropy term; the
            in-session generative term is weighted by `1 - alpha`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: `none` | `mean`. `none`: no reduction will be applied;
            `mean`: the weighted mean of the output is taken.
    """

    def __init__(
self,
alpha: float = 0.5,
reduction: str = "mean",
) -> None:
super().__init__()
self._alpha = alpha
self._reduction = reduction
self._ce_loss = CrossEntropyLoss(reduction=reduction)

    def forward(
self,
logits: Tensor,
labels: Tensor,
session_ids: Tensor,
) -> Tensor:
"""JRC loss.
Args:
logits: a `Tensor` with shape [batch_size, 2].
labels: a `Tensor` with shape [batch_size].
session_ids: a `Tensor` with shape [batch_size].
Return:
loss: a `Tensor` with shape [batch_size] if reduction is 'none',
otherwise with shape ().
"""
ce_loss = self._ce_loss(logits, labels)
batch_size = labels.shape[0]
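        # mask[i, j] = 1 when samples i and j belong to the same session.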
mask = torch.eq(session_ids.unsqueeze(1), session_ids.unsqueeze(0)).float()
diag_index = _diag_index(labels)
logits_neg, logits_pos = logits[:, 0], logits[:, 1]
diag = _label_mask(labels)
pos_num = torch.sum(labels)
neg_num = batch_size - pos_num

        # First, compute the generative (in-session) loss for positive samples.
pos_mask_index = torch.where(labels == 1.0)[0]
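        # Targets for the in-session softmax: each positive row should select
        # its own (diagonal) logit.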
pos_diag_label = torch.index_select(diag_index, 0, pos_mask_index)
# pyre-ignore [6]
logits_pos = logits_pos.unsqueeze(0).tile([pos_num, 1])
pos_session_mask = torch.index_select(mask, 0, pos_mask_index)
# pyre-ignore [6]
y_pos = labels.unsqueeze(0).tile([pos_num, 1])
diag_pos = torch.index_select(diag, 0, pos_mask_index)
        # Mask (with a large negative logit) entries that are not in the same
        # session, as well as off-diagonal positives: each positive competes
        # only against itself and the negatives of its own session.
logits_pos = (
logits_pos + ((1 - pos_session_mask) + (1 - diag_pos) * y_pos) * -1e9
)
loss_pos = self._ce_loss(logits_pos, pos_diag_label)

        # Next, compute the generative (in-session) loss for negative samples.
neg_mask_index = torch.where(labels == 0.0)[0]
neg_diag_label = torch.index_select(diag_index, 0, neg_mask_index)
logits_neg = logits_neg.unsqueeze(0).tile([neg_num, 1])
neg_session_mask = torch.index_select(mask, 0, neg_mask_index)
y_neg = (1 - labels).unsqueeze(0).tile([neg_num, 1])
diag_neg = torch.index_select(diag, 0, neg_mask_index)
        # Mask entries that are not in the same session, as well as
        # off-diagonal negatives: each negative competes only against itself
        # and the positives of its own session.
logits_neg = (
logits_neg + ((1 - neg_session_mask) + (1 - diag_neg) * y_neg) * -1e9
)
loss_neg = self._ce_loss(logits_neg, neg_diag_label)
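
        # Under `mean` reduction, weight each per-class mean by its class
        # frequency so the two terms combine into a batch-level mean; under
        # `none`, scatter per-sample losses back to their batch positions.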
if self._reduction != "none":
loss_pos = loss_pos * pos_num / batch_size
loss_neg = loss_neg * neg_num / batch_size
ge_loss = loss_pos + loss_neg
else:
ge_loss = torch.zeros_like(labels, dtype=torch.float)
ge_loss.index_put_(torch.where(labels == 1.0), loss_pos)
ge_loss.index_put_(torch.where(labels == 0.0), loss_neg)
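        # Final loss: weighted combination of the pointwise cross-entropy
        # term and the in-session generative term.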
loss = self._alpha * ce_loss + (1 - self._alpha) * ge_loss
# pyre-ignore [7]
return loss
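

# Minimal usage sketch (illustrative, not part of the original module; the
# shapes and values below are made up to show the expected call pattern):
#
#     loss_fn = JRCLoss(alpha=0.5)
#     logits = torch.randn(4, 2)                   # [batch_size, 2]
#     labels = torch.tensor([1, 0, 1, 0])          # binary click labels
#     session_ids = torch.tensor([7, 7, 9, 9])     # two sessions of two samples
#     loss = loss_fn(logits, labels, session_ids)  # scalar under `mean` reduction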