# tinynn/graph/configs/gen_funcs_yml.py
import importlib
import importlib.util
import inspect
import operator

import torch
import torchvision.ops
import yaml
from torch.overrides import get_ignored_functions, get_overridable_functions, get_testing_overrides
# TODO: Better detection
# Stage 1: Functions in get_overridable_functions()
# Seed the catalogue: every overridable function torch reports, grouped under
# the dotted name of its owning namespace (class or module).
func_dict = get_overridable_functions()
final_dict = {}
for owner, members in func_dict.items():
    if not isinstance(owner, type) and not inspect.ismodule(owner):
        continue
    if type(members) is not list:
        continue
    # Classes expose __module__ (e.g. torch.Tensor); plain modules only __name__.
    if hasattr(owner, '__module__'):
        ns_key = f'{owner.__module__}.{owner.__name__}'
        label = (owner.__module__, owner.__name__)
    elif hasattr(owner, '__name__'):
        ns_key = owner.__name__
        label = (owner.__name__,)
    else:
        continue
    for member in members:
        print(*label, member.__name__)
        bucket = final_dict.setdefault(ns_key, [])
        if member.__name__ not in bucket:
            bucket.append(member.__name__)
# Stage 2: Functions in get_testing_overrides()
# Map each override-testable callable to its public dotted name, normalizing
# the private torch._C.* / torch._tensor namespaces to their public aliases.
func_dict = get_testing_overrides()
for f in func_dict:
    qualname = getattr(f, '__qualname__', None)
    module = getattr(f, '__module__', None)
    name = getattr(f, '__name__', None)
    # Resolved public dotted name for this callable; None means "unresolvable,
    # skip". BUG FIX: the original left `fullname` unset in that case and fell
    # through to the append below, re-recording the stale value from the
    # previous loop iteration (or raising NameError on the first one).
    fullname = None
    if module and name:
        if module == 'torch._tensor':
            module = 'torch.Tensor'
            assert hasattr(torch.Tensor, name), f'{module}.{name}'
        elif module == 'torch._C._linalg':
            assert name.startswith('linalg_')
            module = 'torch.linalg'
            name = name.replace('linalg_', '')
            assert hasattr(torch.linalg, name), f'{module}.{name}'
        elif module == 'torch._C.nn':
            module = 'torch.nn'
            assert hasattr(torch.nn, name), f'{module}.{name}'
        elif module in ('torch._C.special', 'torch._C._special'):
            module = 'torch.special'
            name = name.replace('special_', '')
            assert hasattr(torch.special, name), f'{module}.{name}'
        elif module in ('torch._C.fft', 'torch._C._fft'):
            module = 'torch.fft'
            name = name.replace('fft_', '')
            assert hasattr(torch.fft, name), f'{module}.{name}'
        elif module == 'torch._C._nn':
            module = 'torch.nn.functional'
            if name == 'log_sigmoid':
                # F exposes this one as `logsigmoid` (no underscore).
                name = name.replace('_', '')
            assert hasattr(torch.nn.functional, name), f'{module}.{name}'
        elif module.startswith('torch._C.'):
            # A private namespace we do not know how to map yet; fail loudly so
            # the mapping table above gets extended.
            print(module, name, 'not recognized')
            assert False
        fullname = f'{module}.{name}'
    elif qualname:
        if qualname.startswith('_VariableFunctionsClass.'):
            fullname = qualname.replace('_VariableFunctionsClass.', 'torch.')
            assert fullname.count('.') == 1
            funcname = fullname.split('.')[1]
            assert hasattr(torch, funcname)
        elif qualname.startswith('torch._tensor.'):
            fullname = qualname.replace('torch._tensor.', 'torch.Tensor.')
            assert fullname.count('.') == 2
            funcname = fullname.split('.')[-1]
            assert hasattr(torch.Tensor, funcname)
        elif qualname.startswith('_TensorBase.'):
            fullname = qualname.replace('_TensorBase.', 'torch.Tensor.')
            assert fullname.count('.') == 2
            funcname = fullname.split('.')[-1]
            assert hasattr(torch.Tensor, funcname)
    if fullname is None:
        continue
    print(fullname)
    # Split "ns.func" at the last dot and record the function under its namespace.
    rdot = fullname.rfind('.')
    ns = fullname[:rdot]
    func = fullname[rdot + 1 :]
    final_dict.setdefault(ns, [])
    final_dict[ns].append(func)
# Stage 3: Functions in get_ignored_functions() for the namespace (torch.nn.functional)
# These carry no __torch_function__ override machinery, but the documented ones
# in torch.nn.functional are still public API worth recording.
funcs = get_ignored_functions()
for f in funcs:
    # NOTE: only __module__/__name__ are needed here; the original also read
    # __qualname__ but never used it.
    module = getattr(f, '__module__', None)
    name = getattr(f, '__name__', None)
    if module == 'torch.nn.functional' and name is not None:
        # Undocumented entries are private implementation details; skip them.
        if f.__doc__ is None:
            continue
        # setdefault guards against the namespace key being absent (the
        # original indexed final_dict[module] directly and could KeyError).
        final_dict.setdefault(module, []).append(name)
# Stage 4: torch.Tensor + operators
# Record every builtin from the `operator` module that torch.Tensor also
# implements (add, matmul, getitem, ...).
for op_name, op_func in vars(operator).items():
    if inspect.isbuiltin(op_func) and hasattr(torch.Tensor, op_name):
        final_dict['torch.Tensor'].append(op_name)
# Stage 5: torch.tensor -> torch.Tensor
# Fold anything recorded under the lowercase factory namespace into the class.
tensor_entries = final_dict.pop('torch.tensor', None)
if tensor_entries is not None:
    final_dict['torch.Tensor'].extend(tensor_entries)
# Stage 6: torchvision ops
# Collect the documented public routines exposed by torchvision.ops.
ops_bucket = final_dict.setdefault('torchvision.ops', [])
for op_name, candidate in vars(torchvision.ops).items():
    if inspect.isroutine(candidate) and candidate.__doc__ is not None:
        ops_bucket.append(op_name)
def get_scope(ns):
spec = importlib.util.find_spec(ns)
if spec is None:
modules = ns.split('.')
ns = '.'.join(modules[:-1])
typename = modules[-1]
spec = importlib.util.find_spec(ns)
if spec is None:
return None
scope = importlib.import_module(ns)
if hasattr(scope, typename):
scope = getattr(scope, typename)
else:
return None
else:
scope = importlib.import_module(ns)
return scope
# Derive the MAJOR_MINOR tag used in the output filename, e.g. '2.1.0' -> '2_1'.
major, minor = torch.__version__.split('.')[:2]
ver = f'{major}_{minor}'
# Stage 7: Functions in new versions may exist in current version
# Backfill from the newest config: a function catalogued for a newer torch may
# already be present in this version without being listed as overridable.
latest = '2_4'
if ver != latest:
    with open(f'torch_func_override_{latest}.yml', 'r') as conf_file:
        latest_dict = yaml.load(conf_file, yaml.SafeLoader)
    for ns, candidates in latest_dict.items():
        if ns not in final_dict:
            continue
        scope = get_scope(ns)
        if scope is None:
            continue
        recorded = final_dict[ns]
        for candidate in candidates:
            if candidate in recorded:
                continue
            if hasattr(scope, candidate) and inspect.isroutine(getattr(scope, candidate)):
                recorded.append(candidate)
                print(ns, candidate)
# Stage 8: Functions may have different names (e.g. F.pad)
# Register aliases: attributes whose name differs from their function's
# __name__ but whose function is already catalogued under this namespace.
for ns in final_dict:
    scope = get_scope(ns)
    if scope is None:
        continue
    recorded = final_dict[ns]
    for attr in dir(scope):
        member = getattr(scope, attr)
        if not inspect.isroutine(member):
            continue
        if member.__name__ in recorded and attr != member.__name__:
            recorded.append(attr)
            print(ns, attr)
# Stage 9: Make the functions unique and sorted
# Slice-assign so each namespace's list object keeps its identity while its
# contents are replaced with a deduplicated, sorted copy.
for names in final_dict.values():
    names[:] = sorted(set(names))
# Stage 10: Update config
# Persist the catalogue for the running torch version.
output_path = f'torch_func_override_{ver}.yml'
with open(output_path, 'w') as out_file:
    yaml.dump(final_dict, out_file)