def call_module()

in arctic_inference/vllm/ulysses.py [0:0]


    def call_module(self, target: torch.fx.node.Target,
                    args: tuple[torch.fx.node.Argument,
                                ...], kwargs: dict[str, Any]) -> Any:
        """Run submodule ``target``; if it is a compile target, install its
        piecewise backend.

        The submodule is always executed through
        ``torch.fx.Interpreter.call_module``. When ``target`` is listed in
        ``self.compile_submod_names``, the submodule is additionally compiled
        for the general shape (``runtime_shape=None``) and
        ``self.module.__dict__[target]`` is replaced with an instance of the
        platform's piecewise backend class wrapping the compiled graph.

        Args:
            target: Name of the submodule to call; must be a ``str``.
            args: Positional arguments forwarded to the submodule. The
                positions holding a ``torch.SymInt`` equal to the symbolic
                shape found by ``self.find_symbolic_shape`` are passed to the
                piecewise backend.
            kwargs: Keyword arguments forwarded to the submodule.

        Returns:
            The value returned by ``torch.fx.Interpreter.call_module``.
        """
        assert isinstance(target, str)
        # [Arctic Inference]
        # Monkeypatching goes through the ArcticPatch class, which inherits
        # from the original class, so the original class' super() is not
        # reachable from here. Instead of super(), invoke call_module directly
        # on torch.fx.Interpreter, the super class of
        # PiecewiseCompileInterpreter.
        # see - v0.9.0.1/compilation/backends.py#L241
        output = torch.fx.Interpreter.call_module(self, target, args, kwargs)

        # Submodules that are not compile targets are only executed.
        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        submod = self.fetch_attr(target)

        # [Arctic Inference]
        # The compiler may create subgraphs with certain symbolic integer
        # values that violate vllm's assumption here:
        # - v0.9.0.1/compilation/base_piecewise_backend.py#L64
        # The index of the significant symbol determines the runtime shape
        # here:
        # - v0.9.0.1/compilation/cuda_piecewise_backend.py#L112
        # The fix relaxes vllm's original assumption that a single symbol
        # determines the shape: we collect the index of every argument that
        # matches the symbolic shape.
        sym_shape = self.find_symbolic_shape(args)
        sym_shape_indices = [
            i for i, arg in enumerate(args)
            if isinstance(arg, torch.SymInt) and sym_shape == arg
        ]

        num_graphs = len(self.compile_submod_names)
        compiler_manager = self.vllm_backend.compiler_manager
        # runtime_shape=None requests compilation for the general shape.
        compiled_graph_for_general_shape = compiler_manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=index,
            num_graphs=num_graphs,
            runtime_shape=None)

        piecewise_backend = resolve_obj_by_qualname(
            current_platform.get_piecewise_backend_cls())
        # Write through __dict__ so the attribute named `target` on the
        # module now refers to the piecewise backend instance.
        self.module.__dict__[target] = piecewise_backend(
            submod, self.vllm_config, self.graph_pool, index, num_graphs,
            sym_shape_indices, compiled_graph_for_general_shape,
            self.vllm_backend)

        from vllm.compilation.counter import compilation_counter
        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output