def __torch_dispatch__()

in functorch/_src/python_key.py [0:0]


    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Record ``func`` on the proxies and evaluate it on the wrapped tensors.

        If ``func`` has an entry in ``CURRENT_DECOMPOSITION_TABLE``, that entry
        is called with the original arguments and its result is returned as-is.
        Otherwise ``func`` is called twice: once with every ``PythonTensor``
        argument replaced by its ``.proxy`` and once with it replaced by its
        ``.elem``.  Tensor results of the second call are re-wrapped as
        ``PythonTensor`` objects paired with the matching result of the first.

        Args:
            cls: The class the dispatch is invoked on (not used in the body).
            func: The operator being dispatched.  Its ``__name__`` is inspected
                to guess whether it is an in-place op.
            types: Not used in the body.
            args: Positional arguments for ``func``; may be arbitrarily nested
                (they are traversed with ``pytree``).
            kwargs: Keyword arguments for ``func``.  ``None`` is treated as an
                empty mapping.

        Returns:
            The decomposition's result when one is registered; otherwise the
            real output of ``func`` with tensors wrapped in ``PythonTensor``.
            Tuple and list outputs are wrapped element by element; any other
            non-tensor output is returned unchanged.
        """
        # ``kwargs`` defaults to None but is unpacked with ``**`` below, which
        # requires a mapping.  Normalize it once so every later use is safe.
        if kwargs is None:
            kwargs = {}

        if func in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)

        # Commenting this out for now since it causes some spurious failures (such as error checking)
        # if func == aten._local_scalar_dense:
        #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
        #                        "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, PythonTensor) else e

        def unwrap_tensor(e):
            return e.elem if isinstance(e, PythonTensor) else e

        # Devices of every tensor leaf in the (still wrapped) inputs.
        input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
                         pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]

        output_device = get_output_device(input_devices, func)

        # First call: run the op on the proxies.
        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not: a trailing
        # underscore without a leading one (e.g. ``add_`` but not ``_foo_``).
        # For such ops the first argument's proxy is replaced by the op's result.
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            args[0].proxy = proxy_out

        # Second call: run the op on the underlying tensors.
        args = pytree.tree_map(unwrap_tensor, args)
        kwargs = pytree.tree_map(unwrap_tensor, kwargs)

        try:
            real_out = func(*args, **kwargs)
        except NotImplementedError:
            # Retry with every tensor input swapped for a ones-filled tensor
            # created on ``output_device``.
            def ones_on_output_device(x):
                if isinstance(x, torch.Tensor):
                    return torch.ones_like(x, device=output_device)
                return x

            args = pytree.tree_map(ones_on_output_device, args)
            kwargs = pytree.tree_map(ones_on_output_device, kwargs)
            real_out = func(*args, **kwargs)

        def wrap_with_proxy(e, proxy):
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())
            # Exact-type check (not isinstance): tensor subclasses are passed
            # through unwrapped, exactly as before.
            # NOTE(review): presumably deliberate - confirm before relaxing.
            if type(e) is torch.Tensor:
                return PythonTensor(e, proxy, output_device)
            return e

        if isinstance(real_out, tuple):
            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out