def handle_hard_coded_axis_param()

in optimum/fx/parallelization/passes.py [0:0]


    def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
        """
        Divide a hard-coded size along the node's parallel axis by the size of `ctx.tp_group`.

        The shape is looked up, in this order, in the node's `size` kwarg, its `shape` kwarg, a tuple
        passed as the second positional argument, or the positional arguments after the first one.
        If the entry at the stored `parallel_axis` is an integer other than `-1`, it is replaced by
        its quotient with `ctx.tp_group.size()` and written back to the node in place. Nodes without
        a stored `parallel_axis`, or whose entry is not such an integer, are left untouched.

        Args:
            node (`Node`):
                Node whose size/shape argument may need to be rewritten.
            ctx (`ParallelExecutionCtx`):
                Execution context; only `ctx.tp_group.size()` is read here.

        Raises:
            AssertionError: If `parallel_axis` is not smaller than the number of shape entries, or if
                the hard-coded size is not divisible by `ctx.tp_group.size()`.
        """

        def extract_shape_from_node(node: Node) -> List[Any]:
            # Always returns a fresh list, so the caller may modify it without touching the node.
            if "size" in node.kwargs:
                return list(node.kwargs["size"])
            elif "shape" in node.kwargs:
                return list(node.kwargs["shape"])
            elif isinstance(node.args[1], tuple):
                # NOTE(review): a shape passed as a list would not match this check and would fall
                # through to the variadic branch below -- confirm the tracer only emits tuples here.
                return list(node.args[1])
            else:
                # Shape given as separate positional arguments: everything after args[0].
                return list(node.args[1:])

        def update(node: Node, new_shape: List[Any], parallel_axis: int) -> None:
            # The branch order must mirror `extract_shape_from_node` so that the value is written
            # back to the same place it was read from.
            if "size" in node.kwargs:
                node.update_kwarg("size", tuple(new_shape))
            elif "shape" in node.kwargs:
                node.update_kwarg("shape", tuple(new_shape))
            elif isinstance(node.args[1], tuple):
                node.update_arg(1, tuple(new_shape))
            else:
                # Shape entry `i` lives at positional index `i + 1` because the shape was read from
                # `node.args[1:]`. Use `new_shape` (the parameter) rather than the enclosing scope's
                # `shape`, so this helper does not depend on a variable it was never given.
                node.update_arg(parallel_axis + 1, new_shape[parallel_axis])

        parallel_axis = ParallelAxisSolverPass.get_stored_field_info(node, field="parallel_axis")
        if parallel_axis is None:
            return

        shape = extract_shape_from_node(node)
        assert parallel_axis < len(shape), (
            f"parallel axis {parallel_axis} is out of range for shape {shape} of node {node}"
        )

        dim_size = shape[parallel_axis]
        # Only a concrete integer size needs rewriting; non-int entries and -1 are left as they are.
        if not isinstance(dim_size, int) or dim_size == -1:
            return

        world_size = ctx.tp_group.size()
        assert dim_size % world_size == 0, (
            f"size {dim_size} along parallel axis {parallel_axis} of node {node} "
            f"is not divisible by world size {world_size}"
        )
        shape[parallel_axis] = dim_size // world_size
        update(node, shape, parallel_axis)