def is_cross_entropy_parallel_compatible()

in optimum/fx/parallelization/utils.py [0:0]


def is_cross_entropy_parallel_compatible(node: Node) -> bool:
    """
    Tell whether a cross-entropy node can be swapped for `VocabParallelCrossEntropyLoss`.

    `VocabParallelCrossEntropyLoss` currently handles neither class weighting, index ignoring,
    nor label smoothing. A node therefore qualifies only when `weight is None`,
    `ignore_index == -100` and `label_smoothing == 0.0`.

    Args:
        node (`Node`):
            The fx node to inspect. For `op == "call_function"` the settings are read from the
            call's keyword arguments, falling back to positional slots 2 (`weight`),
            4 (`ignore_index`) and 7 (`label_smoothing`). For `op == "call_module"` they are read
            from the submodule that `node.target` resolves to on the owning module.

    Returns:
        `bool`: whether the three settings are all at their defaults; `False` for any other `op`.
    """

    def _all_defaults(weight, ignore_index, label_smoothing):
        # Same short-circuit order as the checks below rely on: weight, then ignore_index,
        # then label_smoothing.
        return weight is None and ignore_index == -100 and label_smoothing == 0.0

    if node.op == "call_module":
        loss_module: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target)
        return _all_defaults(loss_module.weight, loss_module.ignore_index, loss_module.label_smoothing)

    if node.op != "call_function":
        return False

    keyword_args = node.kwargs
    weight = keyword_args.get("weight", None)
    ignore_index = keyword_args.get("ignore_index", -100)
    label_smoothing = keyword_args.get("label_smoothing", 0.0)

    positional_args = node.args
    num_positional = len(positional_args)
    # A positional slot is consulted only while the keyword-derived value is still the default.
    if num_positional > 2 and weight is None:
        weight = positional_args[2]
    if num_positional > 4 and ignore_index == -100:
        ignore_index = positional_args[4]
    if num_positional > 7 and label_smoothing == 0.0:
        label_smoothing = positional_args[7]

    return _all_defaults(weight, ignore_index, label_smoothing)