def __call__()

in deep_gemm/jit/runtime.py [0:0]


    def __call__(self, *args) -> int:
        """Check ``args`` against the kernel's argument spec and launch it.

        If either ``self.lib`` or ``self.args`` is still ``None``, both are
        (re)loaded from ``self.path``: ``kernel.so`` through ``ctypes.CDLL``
        and ``kernel.args`` by evaluating its text. The spec is iterated as
        ``(name, dtype)`` pairs, one per positional argument.

        Args:
            *args: Kernel arguments, in spec order. ``torch.Tensor`` values
                are compared on ``.dtype``; any other value must be an
                instance of the spec's type. Each value is converted with
                ``map_ctype`` before being passed to the library.

        Returns:
            The value of the ``c_int`` handed by reference as the final
            argument to the library's ``launch`` symbol (initialised to 0;
            presumably set by ``launch`` as a status code).

        Raises:
            AssertionError: On an argument-count, tensor-dtype or type
                mismatch. These are plain ``assert`` checks, so they vanish
                when Python runs with ``-O``.
        """
        # Lazily bind the shared library first, then its argument spec.
        if self.lib is None or self.args is None:
            so_file = os.path.join(self.path, 'kernel.so')
            self.lib = ctypes.CDLL(so_file)
            spec_file = os.path.join(self.path, 'kernel.args')
            with open(spec_file, 'r') as spec_stream:
                # NOTE(review): ``eval`` runs arbitrary code from the kernel
                # directory. A literal parser cannot be used if the spec holds
                # objects such as torch dtypes, so this directory must only
                # ever contain trusted JIT output.
                self.args = eval(spec_stream.read())

        assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'

        # Validate and convert each value in order; conversion of value k
        # happens only after values 0..k have passed their checks.
        c_values = []
        for value, (param_name, expected) in zip(args, self.args):
            if not isinstance(value, torch.Tensor):
                assert isinstance(value, expected), f'Expected built-in type `{expected}` for `{param_name}`, got `{type(value)}`'
            else:
                assert value.dtype == expected, f'Expected tensor dtype `{expected}` for `{param_name}`, got `{value.dtype}`'
            c_values.append(map_ctype(value))

        # The status slot is passed by pointer as the trailing argument.
        status = ctypes.c_int(0)
        self.lib.launch(*c_values, ctypes.byref(status))
        return status.value