in optimum/fx/parallelization/passes.py [0:0]
def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
    """Shard a hard-coded shape argument of ``node`` along its stored parallel axis.

    Reads the ``parallel_axis`` field stored on ``node`` via
    ``ParallelAxisSolverPass.get_stored_field_info``. If an axis is stored and the
    shape entry at that axis is a plain ``int`` other than ``-1``, the entry is
    divided by ``ctx.tp_group.size()`` and written back into ``node`` in place.

    The shape argument is looked up, in order, in:
      1. the ``size`` kwarg,
      2. the ``shape`` kwarg,
      3. ``node.args[1]`` when it is a tuple or list (e.g. ``x.view((a, b))`` or
         ``x.view([a, b])``),
      4. otherwise the remaining positional args ``node.args[1:]`` (varargs form,
         e.g. ``x.view(a, b)``).

    Args:
        node: Graph node whose shape argument may be rewritten in place.
        ctx: Parallel execution context; ``ctx.tp_group.size()`` supplies the
            divisor applied to the sharded dimension.

    Raises:
        AssertionError: If the stored axis is out of range for the extracted
            shape, or the dimension is not divisible by the group size.
    """

    def extract_shape_from_node(node: Node) -> List[Any]:
        # Return a mutable copy of the node's shape argument (see lookup order above).
        if "size" in node.kwargs:
            return list(node.kwargs["size"])
        if "shape" in node.kwargs:
            return list(node.kwargs["shape"])
        # torch.fx stores list arguments as `immutable_list` (a list subclass), so a
        # tuple-only check would misroute `x.view([a, b])` into the varargs branch.
        if isinstance(node.args[1], (tuple, list)):
            return list(node.args[1])
        # Varargs form: every positional arg after args[0] is one dimension.
        return list(node.args[1:])

    def update(node: Node, new_shape: List[Any], parallel_axis: int) -> None:
        # Write `new_shape` back into the same slot `extract_shape_from_node` read from.
        if "size" in node.kwargs:
            node.update_kwarg("size", tuple(new_shape))
        elif "shape" in node.kwargs:
            node.update_kwarg("shape", tuple(new_shape))
        elif isinstance(node.args[1], (tuple, list)):
            node.update_arg(1, tuple(new_shape))
        else:
            # Varargs form: only the sharded dimension changes. The +1 offset skips
            # args[0], which is not part of the shape (presumably the input tensor).
            node.update_arg(parallel_axis + 1, new_shape[parallel_axis])

    parallel_axis = ParallelAxisSolverPass.get_stored_field_info(node, field="parallel_axis")
    if parallel_axis is None:
        return

    shape = extract_shape_from_node(node)
    assert parallel_axis < len(shape), (
        f"parallel axis {parallel_axis} is out of range for shape {shape} of node {node}"
    )
    dim = shape[parallel_axis]
    # Only concrete ints can be rewritten statically. Non-int entries (presumably
    # values computed inside the graph) and -1 (typically an inferred dimension in
    # view/reshape-style ops) are left untouched.
    if not isinstance(dim, int) or dim == -1:
        return

    world_size = ctx.tp_group.size()
    assert dim % world_size == 0, (
        f"dimension {dim} at axis {parallel_axis} of node {node} is not divisible "
        f"by tensor-parallel group size {world_size}"
    )
    shape[parallel_axis] = dim // world_size
    update(node, shape, parallel_axis)