in optimum/fx/parallelization/utils.py [0:0]
def is_cross_entropy_parallel_compatible(node: Node) -> bool:
    """
    Report whether a cross-entropy node relies only on default loss settings.

    For now `VocabParallelCrossEntropyLoss` does not support weighted mode, index ignoring and label
    smoothing, so a node qualifies only when `weight` is `None`, `ignore_index` equals `-100` and
    `label_smoothing` equals `0.0`.

    Args:
        node: graph node to inspect. Only `call_function` and `call_module` nodes can qualify.

    Returns:
        `True` when all three settings hold the values above, `False` otherwise (including for any
        other kind of node).
    """

    def _uses_only_defaults(weight, ignore_index, label_smoothing) -> bool:
        # Single place where the "nothing unsupported was requested" rule is spelled out.
        return weight is None and ignore_index == -100 and label_smoothing == 0.0

    op = node.op

    if op == "call_module":
        # The settings live on the submodule instance the node targets.
        loss_module: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target)
        module_weight = loss_module.weight
        module_label_smoothing = loss_module.label_smoothing
        module_ignore_index = loss_module.ignore_index
        return _uses_only_defaults(module_weight, module_ignore_index, module_label_smoothing)

    if op != "call_function":
        return False

    # Keyword arguments are consulted first; a setting still at its default may instead have
    # been passed positionally.
    keyword_args = node.kwargs
    weight = keyword_args.get("weight", None)
    ignore_index = keyword_args.get("ignore_index", -100)
    label_smoothing = keyword_args.get("label_smoothing", 0.0)

    # Positional slots 2 / 4 / 7 presumably mirror the signature of
    # `torch.nn.functional.cross_entropy` -- confirm against the call sites being traced.
    positional_args = node.args
    positional_count = len(positional_args)
    if positional_count > 2 and weight is None:
        weight = positional_args[2]
    if positional_count > 4 and ignore_index == -100:
        ignore_index = positional_args[4]
    if positional_count > 7 and label_smoothing == 0.0:
        label_smoothing = positional_args[7]

    return _uses_only_defaults(weight, ignore_index, label_smoothing)