in arctic_inference/vllm/ulysses.py [0:0]
def call_module(self, target: torch.fx.node.Target,
                args: tuple[torch.fx.node.Argument, ...],
                kwargs: dict[str, Any]) -> Any:
    """Run submodule ``target`` and, if it is a piecewise-compile target,
    swap it for a piecewise backend instance.

    The call is first delegated to ``torch.fx.Interpreter.call_module``.
    When ``target`` is listed in ``self.compile_submod_names``, the
    submodule is compiled for the general shape (``runtime_shape=None``)
    via ``self.vllm_backend.compiler_manager`` and
    ``self.module.__dict__[target]`` is overwritten with an instance of
    the class named by ``current_platform.get_piecewise_backend_cls()``.

    Args:
        target: Name of the submodule to call; must be a ``str``.
        args: Positional arguments forwarded to the submodule.
        kwargs: Keyword arguments forwarded to the submodule.

    Returns:
        Whatever ``torch.fx.Interpreter.call_module`` returned for this
        call.
    """
    assert isinstance(target, str)

    # [Arctic Inference]
    # This method is installed by monkeypatching through the ArcticPatch
    # class, which inherits from the original class, so super() of the
    # original class is not reachable from here. Call the implementation
    # of torch.fx.Interpreter (the super class of
    # PiecewiseCompileInterpreter) explicitly instead.
    # see - v0.9.0.1/compilation/backends.py#L241
    result = torch.fx.Interpreter.call_module(self, target, args, kwargs)

    # Submodules that are not compile targets are simply executed.
    if target not in self.compile_submod_names:
        return result

    graph_index = self.compile_submod_names.index(target)
    graph_count = len(self.compile_submod_names)
    submodule = self.fetch_attr(target)

    # [Arctic Inference]
    # The compiler may create subgraphs carrying several symbolic integer
    # values, which violates vllm's assumption here:
    # - v0.9.0.1/compilation/base_piecewise_backend.py#L64
    # The index of the significant symbol determines the runtime shape here:
    # - v0.9.0.1/compilation/cuda_piecewise_backend.py#L112
    # The fix relaxes vllm's assumption that a single symbol determines the
    # shape: only the positions of the arguments equal to the symbol
    # returned by find_symbolic_shape are recorded.
    shape_symbol = self.find_symbolic_shape(args)
    sym_shape_indices = [
        position for position, argument in enumerate(args)
        if isinstance(argument, torch.SymInt) and shape_symbol == argument
    ]

    # NOTE(review): this name is declared global but is never read or
    # assigned within this method; the declaration is retained as-is.
    global compilation_start_time

    compiler_manager = self.vllm_backend.compiler_manager
    general_shape_graph = compiler_manager.compile(
        submodule,
        args,
        self.compilation_config.inductor_compile_config,
        self.compilation_config,
        graph_index=graph_index,
        num_graphs=graph_count,
        runtime_shape=None,
    )

    # Resolve the backend class from its qualified name and install an
    # instance in place of the original submodule.
    backend_cls = resolve_obj_by_qualname(
        current_platform.get_piecewise_backend_cls())
    self.module.__dict__[target] = backend_cls(
        submodule,
        self.vllm_config,
        self.graph_pool,
        graph_index,
        graph_count,
        sym_shape_indices,
        general_shape_graph,
        self.vllm_backend,
    )

    # Imported at call time, as in the original code.
    from vllm.compilation.counter import compilation_counter
    compilation_counter.num_piecewise_capturable_graphs_seen += 1

    return result