def prepare_torch_overrides_funcs(funcs)

in tinynn/graph/tracer.py

Collects the __torch_function__ dispatch helpers (has_torch_function, its unary/variadic variants, and handle_torch_function) from torch.overrides and from every namespace in funcs, and pairs each group of helpers with the wrapper generator that patches it.


import inspect
import sys

import torch

# Note: new_has_torch_func_gen and new_handle_func_gen are wrapper generators
# defined elsewhere in tracer.py.


def prepare_torch_overrides_funcs(funcs):
    tracked_funcs = []
    wrappers = []
    # torch.overrides is only available as a module on newer PyTorch
    # releases; on older ones there is nothing to track here.
    if hasattr(torch, 'overrides') and inspect.ismodule(torch.overrides):
        # Dispatch helpers that torch.overrides may provide, depending on the
        # PyTorch version.
        all_has_torch_func_names = ['has_torch_function', 'has_torch_function_unary', 'has_torch_function_variadic']
        all_handle_func_names = ['handle_torch_function']

        # Keep only the helpers this PyTorch build actually exposes.
        has_torch_func_names = [n for n in all_has_torch_func_names if hasattr(torch.overrides, n)]
        handle_func_names = [n for n in all_handle_func_names if hasattr(torch.overrides, n)]

        # Map each namespace to the helper names found on it, starting with
        # torch.overrides itself.
        has_torch_funcs = {torch.overrides: has_torch_func_names}
        handle_funcs = {torch.overrides: handle_func_names}

        for ns in funcs.keys():
            # Tensor methods live in the torch._tensor module on recent
            # PyTorch versions and in torch.tensor on older ones.
            if ns == torch.Tensor:
                if hasattr(torch, '_tensor') and inspect.ismodule(torch._tensor):
                    ns = torch._tensor
                else:
                    ns = sys.modules['torch.tensor']

            # Record the helpers that were re-exported into this namespace.
            ns_has_torch_func_names = [k for k in has_torch_func_names if hasattr(ns, k)]
            ns_handle_func_names = [k for k in handle_func_names if hasattr(ns, k)]

            if ns_has_torch_func_names:
                has_torch_funcs[ns] = ns_has_torch_func_names

            if ns_handle_func_names:
                handle_funcs[ns] = ns_handle_func_names

        # Pair each group of tracked helpers with the wrapper generator that
        # produces its patched replacement.
        tracked_funcs.extend((has_torch_funcs, handle_funcs))
        wrappers.extend((new_has_torch_func_gen, new_handle_func_gen))

    return tracked_funcs, wrappers
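
A minimal usage sketch, assuming tinynn is installed; the funcs mapping below is a hypothetical input, not the mapping the tracer builds internally. The two dicts in tracked_funcs line up positionally with the two generators in wrappers, so a caller can walk them together:

import torch

from tinynn.graph.tracer import prepare_torch_overrides_funcs

# Hypothetical input: the tracer normally builds its own namespace-to-function
# mapping; here we probe the top-level torch namespace and torch.Tensor (which
# the function remaps to the module where Tensor methods live).
funcs = {torch: ['cat', 'stack'], torch.Tensor: ['add']}

tracked_funcs, wrappers = prepare_torch_overrides_funcs(funcs)

# tracked_funcs == [has_torch_funcs, handle_funcs]; each dict maps a module to
# the dispatch-helper names found on it, and wrappers[i] is the generator
# intended to wrap the functions listed in tracked_funcs[i].
for name_map, wrapper_gen in zip(tracked_funcs, wrappers):
    for ns, names in name_map.items():
        print(f'{ns.__name__}: {names} -> {wrapper_gen.__name__}')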