def sharded_cross_entropy_wrapper_fn(process_group)

in optimum/fx/parallelization/parallel_layers/loss.py
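
Returns a wrapper that mimics the signature of `torch.nn.functional.cross_entropy` but computes the loss over vocabulary-sharded logits via `sharded_cross_entropy`, mapping the deprecated `size_average`/`reduce` flags onto the equivalent `reduction` mode.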


from functools import wraps
from typing import Optional

import torch
import torch.distributed as dist


# `sharded_cross_entropy` is defined earlier in this same file
def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
    @wraps(sharded_cross_entropy)
    def wrapper(
        sharded_logits: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ):
        # the sharded implementation only computes plain cross entropy, so reject
        # unsupported options up front
        if weight is not None or ignore_index != -100 or label_smoothing != 0.0:
            raise ValueError(
                "The current parallel cross entropy implementation does not support class weights, "
                "ignore_index or label smoothing."
            )
        # per-token loss computed over the vocabulary shards in the process group
        loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)

        # map the deprecated `size_average` / `reduce` flags onto a `reduction` mode,
        # mirroring the legacy behavior of `torch.nn.functional.cross_entropy`
        if size_average is not None or reduce is not None:
            size_average = True if size_average is None else size_average
            reduce = True if reduce is None else reduce

            if size_average and reduce:
                reduction = "mean"
            elif reduce:
                reduction = "sum"
            else:
                reduction = "none"

        # apply the requested reduction to the per-token loss
        if reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        return loss

    return wrapper
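
A minimal usage sketch, assuming `torch.distributed` is already initialized and the vocabulary dimension of the logits is sharded across the group; the group construction, shapes and vocabulary size below are illustrative assumptions, not part of the original file:

import torch
import torch.distributed as dist

# hypothetical tensor-parallel group spanning all ranks (illustrative only)
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
loss_fn = sharded_cross_entropy_wrapper_fn(tp_group)

# assumed shapes: each rank holds its shard of the vocabulary dimension
vocab_size, world_size = 50264, dist.get_world_size()
sharded_logits = torch.randn(8, vocab_size // world_size)
target = torch.randint(0, vocab_size, (8,))

loss = loss_fn(sharded_logits, target, reduction="mean")  # scalar tensor
per_token = loss_fn(sharded_logits, target, reduce=False)  # legacy flag, maps to reduction="none"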