in tinynn/graph/tracer.py [0:0]
def prepare_torch_overrides_funcs(funcs):
    """Collect the ``torch.overrides`` dispatch helpers that should be patched.

    First determines which of ``has_torch_function``,
    ``has_torch_function_unary``, ``has_torch_function_variadic`` and
    ``handle_torch_function`` exist in ``torch.overrides`` for the installed
    PyTorch version. Then, for every namespace in ``funcs``, records which of
    those same names that namespace also exposes.

    Args:
        funcs: Mapping keyed by namespace. Only the keys are read here.
            Presumably the keys are modules or ``torch.Tensor``; verify
            against callers. A ``torch.Tensor`` key is redirected to the
            module that defines it: ``torch._tensor`` when present, otherwise
            ``sys.modules['torch.tensor']``.

    Returns:
        A ``(tracked_funcs, wrappers)`` tuple of lists.

        When ``torch.overrides`` is unavailable, both lists are empty.
        Otherwise ``tracked_funcs`` holds two dicts:

        * namespace -> available ``has_torch_function*`` names
        * namespace -> available ``handle_torch_function`` names

        The ``torch.overrides`` module is always a key in both dicts. Other
        namespaces appear only when they expose at least one name.
        ``wrappers`` holds ``new_has_torch_func_gen`` and
        ``new_handle_func_gen`` in the same order as the dicts. They are
        presumably paired index-by-index by the caller; confirm against
        callers.

    Raises:
        KeyError: if a key equals ``torch.Tensor``, ``torch._tensor`` is not
            a module, and ``'torch.tensor'`` is not in ``sys.modules``.
    """
    # Guard clause: without a torch.overrides module there is nothing to track.
    if not (hasattr(torch, 'overrides') and inspect.ismodule(torch.overrides)):
        return [], []

    overrides = torch.overrides

    # The set of helpers varies across PyTorch releases, so keep only the
    # names this installation actually provides.
    candidate_has_names = (
        'has_torch_function',
        'has_torch_function_unary',
        'has_torch_function_variadic',
    )
    available_has = [name for name in candidate_has_names if hasattr(overrides, name)]
    available_handle = [
        name for name in ('handle_torch_function',) if hasattr(overrides, name)
    ]

    has_torch_funcs = {overrides: available_has}
    handle_funcs = {overrides: available_handle}

    for namespace in funcs:
        if namespace == torch.Tensor:
            # The module defining Tensor was renamed across releases
            # (torch.tensor -> torch._tensor).
            if hasattr(torch, '_tensor') and inspect.ismodule(torch._tensor):
                namespace = torch._tensor
            else:
                namespace = sys.modules['torch.tensor']

        # Only record names this namespace actually exposes.
        exposed_has = [name for name in available_has if hasattr(namespace, name)]
        exposed_handle = [name for name in available_handle if hasattr(namespace, name)]
        if exposed_has:
            has_torch_funcs[namespace] = exposed_has
        if exposed_handle:
            handle_funcs[namespace] = exposed_handle

    return [has_torch_funcs, handle_funcs], [new_has_torch_func_gen, new_handle_func_gen]