in functorch/_src/python_key.py [0:0]
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    # If a Python-level decomposition is registered for this op, run it instead;
    # the decomposed ops re-enter this dispatch and are traced individually.
    if func in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
    # Commenting this out for now since it causes some spurious failures (such as error checking)
    # if func == aten._local_scalar_dense:
    #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
    #                        "It's likely that this is caused by data-dependent control flow or similar.")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, PythonTensor) else e

    def unwrap_tensor(e):
        return e.elem if isinstance(e, PythonTensor) else e
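
    # Infer which device the concrete output should live on from the tensor inputs.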
    input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
                     pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]
    output_device = get_output_device(input_devices, func)
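
    # Replay the op on the FX proxies so it is recorded as a node in the traced graph.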
    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
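
    # Unwrap to the underlying real tensors and execute the op on them so the output
    # keeps concrete shapes, dtypes, and values; if that raises NotImplementedError,
    # retry with ones_like placeholders on the inferred output device.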
    args = pytree.tree_map(unwrap_tensor, args)
    kwargs = pytree.tree_map(unwrap_tensor, kwargs)
    try:
        real_out = func(*args, **kwargs)
    except NotImplementedError:
        args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
                               if isinstance(x, torch.Tensor) else x, args)
        kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
                                 if isinstance(x, torch.Tensor) else x, kwargs)
        real_out = func(*args, **kwargs)
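
    # Wrap each concrete output back into a PythonTensor that carries its FX proxy,
    # so subsequent ops on the result continue to be traced.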
    def wrap_with_proxy(e, proxy):
        # Some ops (like native_batch_norm_backward) return undefined tensors that get
        # converted into None in python.
        # As the function signature expects tensors, if we directly return these None
        # tensors back to C++, we'll error.
        if e is None:
            e = torch.empty(())
        if type(e) == torch.Tensor:
            return PythonTensor(e, proxy, output_device)
        else:
            return e

    if isinstance(real_out, tuple):
        return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
    elif isinstance(real_out, list):
        return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
    elif isinstance(real_out, torch.Tensor):
        return wrap_with_proxy(real_out, proxy_out)
    else:
        return real_out
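

# Illustrative usage sketch (an assumption, not taken from the file above): it assumes
# make_fx, defined alongside PythonTensor in this module and driving this dispatch
# path, is importable from functorch; the exact import location may differ by version.
import torch
from functorch import make_fx

def f(x):
    # Each aten op below reaches __torch_dispatch__: it is recorded as an FX node
    # via the proxies while also being executed on the real tensor.
    return torch.sin(x) * 2

traced = make_fx(f)(torch.randn(3))
print(traced.graph)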