tinynn/graph/configs/gen_creation_funcs_yml.py (34 lines of code) (raw):
import inspect
import torch
import yaml
import re
from torch.overrides import get_overridable_functions
# TODO: Better detection
# The list is not complete, some functions are missing.
# Names in the ``torch`` namespace that are never emitted, because they are
# functions that we cannot translate (e.g. from_numpy).
block_list = ['from_numpy', 'frombuffer']

# Captures whatever follows the first ``->`` on a docstring line, e.g. the
# ``Tensor`` in ``zeros(*size, ...) -> Tensor``.
_RETURN_ANNOTATION = re.compile(r'-> +(.*)')

# Result-type guesses keyed by name prefix, tried in order when the docstring
# carries no ``->`` annotation.
_PREFIX_GUESSES = (
    ('is_', 'bool [guess]'),
    ('from_', 'Tensor'),
    ('set_', 'None [guess]'),
)


def guess_result_type(name, doc):
    """Return the textual result type of a builtin, or ``None`` if unknown.

    The text after the first ``->`` in ``doc`` wins.  When ``doc`` is
    empty/``None`` or contains no ``->``, the result is guessed from the
    prefix of ``name`` (``is_``, ``from_``, ``set_``).

    NOTE(review): the nesting of the prefix fallback in the previous revision
    was ambiguous (missing docstring vs. docstring without an annotation);
    it is applied in both situations here -- confirm this is intended.

    Args:
        name: Attribute name of the function inside ``torch``.
        doc: The function's ``__doc__`` (may be ``None`` or empty).

    Returns:
        The result type as a string, or ``None`` when nothing could be
        derived from either the docstring or the name.
    """
    if doc:
        found = _RETURN_ANNOTATION.search(doc)
        if found:
            return found.group(1)
    for prefix, guess in _PREFIX_GUESSES:
        if name.startswith(prefix):
            return guess
    return None


def collect_creation_funcs():
    """Collect the names in ``torch`` that this script treats as creation funcs.

    A name is selected when it is either

    * a class whose name ends with ``Tensor`` and whose first base is
      ``object``, or
    * a builtin that is not reported by ``get_overridable_functions()`` for
      ``torch``, is neither prefixed nor suffixed with ``_``, and whose
      result type (see ``guess_result_type``) ends with ``Tensor``.

    Every selected name is echoed to stdout as it is found.

    Returns:
        The selected names, in ``torch.__dict__`` order.
    """
    overridable = set(get_overridable_functions()[torch])
    selected = []
    # Iterate over a snapshot of the keys so the scan cannot fail if the
    # module namespace is modified while we are walking it.
    for name in tuple(torch.__dict__):
        if name in block_list:
            continue
        obj = getattr(torch, name)
        if inspect.isclass(obj):
            if name.endswith('Tensor') and obj.__bases__[0] is object:
                print(name)
                selected.append(name)
            continue
        if not inspect.isbuiltin(obj):
            continue
        # Skip overridable functions as well as private (``_foo``) and
        # trailing-underscore (``foo_``) names.
        if obj in overridable or name.startswith('_') or name.endswith('_'):
            continue
        result_type = guess_result_type(name, obj.__doc__)
        if result_type and result_type.endswith('Tensor'):
            print(name, result_type)
            selected.append(name)
    return selected


def main():
    """Write the collected names to ``torch_creation_funcs_override.yml``."""
    final_dict = {'torch': collect_creation_funcs()}
    with open('torch_creation_funcs_override.yml', 'w') as f:
        yaml.dump(final_dict, f)


if __name__ == '__main__':
    main()