scripts/gen_op_docs.py (84 lines of code) (raw):
import os
import inspect
import re
from tinynn.converter.operators.torch import OPERATOR_CONVERTER_DICT
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter, TrackConstantOperator
# Absolute path of the directory that contains this script.
CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

# Operator names grouped by category. They start out empty and are filled
# (and then sorted) by collect_ops().
prim_ops = []  # converters derived from PrimOperatorConverter
aten_ops = []  # converters defined in tinynn.converter.operators.torch.aten
quantized_ops = []  # converters defined in tinynn.converter.operators.torch.quantized
torchvision_ops = []  # converters defined in tinynn.converter.operators.torch.torchvision
passthrough_ops = []  # converters derived from NoTrackOperator
track_constant_ops = []  # converters derived from TrackConstantOperator

# Operator name -> '<br>'-joined assertion messages found in the source of the
# converter's parse() method. Operators without such messages have no entry.
limitation_dict = {}
def main():
    """Collect the operator categories, then write the operator matrix file."""
    collect_ops()
    update_file()
def collect_ops():
    """Sort every registered converter into a category and record its limitations.

    Each ``(name, converter class)`` pair of ``OPERATOR_CONVERTER_DICT`` is
    classified by the first rule that matches:

    * subclass of ``PrimOperatorConverter``  -> ``prim_ops``
    * subclass of ``NoTrackOperator``        -> ``passthrough_ops``
    * subclass of ``TrackConstantOperator``  -> ``track_constant_ops``
    * otherwise, by the module that defines the class -> ``aten_ops``,
      ``quantized_ops`` or ``torchvision_ops``

    In addition, the source of the converter's ``parse`` method is scanned for
    ``assert <condition>, '<message>'`` statements. When at least one is found,
    the messages are joined with ``<br>`` and stored in ``limitation_dict``
    under the operator name.

    The results are published through the module-level globals; every list is
    sorted. Previous results are discarded first, so the function may be
    called more than once without duplicating entries.

    Raises:
        AssertionError: if a converter matches none of the categories above.
    """
    global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, track_constant_ops

    # Work on fresh containers so a repeated call does not append to the
    # results of an earlier one.
    prim, aten, quantized, torchvision, passthrough, track_constant = [], [], [], [], [], []
    limitation_dict.clear()

    # Fallback classification for converters that match none of the base classes.
    module_buckets = {
        'tinynn.converter.operators.torch.aten': aten,
        'tinynn.converter.operators.torch.quantized': quantized,
        'tinynn.converter.operators.torch.torchvision': torchvision,
    }

    # inspect.getsource() returns the method with its original indentation, so
    # the amount of whitespace in front of `assert` is not fixed: accept any run
    # of spaces/tabs. The condition may span several lines; the message may not.
    assert_pattern = re.compile(
        r'^[ \t]*assert (?:\n|.)*?, (?:\'|")(?P<msg>.*?)(?:\'|")',
        flags=re.MULTILINE,
    )

    for name, converter in OPERATOR_CONVERTER_DICT.items():
        if issubclass(converter, PrimOperatorConverter):
            prim.append(name)
        elif issubclass(converter, NoTrackOperator):
            passthrough.append(name)
        elif issubclass(converter, TrackConstantOperator):
            track_constant.append(name)
        else:
            bucket = module_buckets.get(converter.__module__)
            if bucket is None:
                # Raised explicitly instead of `assert False` so that the check
                # is not stripped when Python runs with -O.
                raise AssertionError(f"Unknown op: {name}, {converter}")
            bucket.append(name)

        parse_source = inspect.getsource(converter.parse)
        messages = [match.group('msg') for match in assert_pattern.finditer(parse_source)]
        if messages:
            limitation_dict[name] = '<br>'.join(messages)

    prim_ops = sorted(prim)
    aten_ops = sorted(aten)
    quantized_ops = sorted(quantized)
    torchvision_ops = sorted(torchvision)
    passthrough_ops = sorted(passthrough)
    track_constant_ops = sorted(track_constant)
def print_operators(topic, ops, f, desc=None, eol=True):
    """Write one Markdown section listing operators and their limitations.

    Nothing is written when ``ops`` is empty.

    Args:
        topic: Section title, emitted as a level-2 heading.
        ops: Operator names, one table row each.
        f: Writable text stream that receives the output.
        desc: Optional line of text placed between the heading and the table.
        eol: When true, a blank line is written after the table.
    """
    if len(ops) == 0:
        return

    # Heading, optional description and the two table header rows.
    preamble = [f'## {topic}\n']
    if desc is not None:
        preamble.append(f'{desc}\n')
    preamble.append('| Operator | Limitations |\n')
    preamble.append('|---------------------------|--------------|\n')
    for line in preamble:
        f.write(line)

    # One row per operator; operators without a recorded limitation get an
    # empty cell.
    for op_name in ops:
        note = limitation_dict.get(op_name, '')
        f.write(f'| `{op_name}` | {note} |\n')

    if eol:
        f.write('\n')
def update_file():
    """Rewrite ``docs/op_matrix.md`` from the collected operator lists.

    The file lives in the ``docs`` directory next to the directory that
    contains this script. It is overwritten with a generated-file banner, a
    title and one section per operator category.
    """
    docs_dir = os.path.join(os.path.dirname(CURRENT_PATH), 'docs')
    target = os.path.join(docs_dir, 'op_matrix.md')

    # (topic, operators, description, trailing blank line) in output order.
    # Only the last section suppresses the trailing blank line.
    sections = (
        ('Primitives', prim_ops, 'Operators that are implemented in Python', True),
        ('ATen Operators', aten_ops, None, True),
        ('Quantized Operators', quantized_ops, None, True),
        ('TorchVision Operators', torchvision_ops, None, True),
        (
            'Passthrough Operators',
            passthrough_ops,
            'Non-tracking operators that are ignored during translation',
            True,
        ),
        (
            'Constant Tracking Operators',
            track_constant_ops,
            'Tracking operators that produce a dynamic constant',
            False,
        ),
    )

    with open(target, 'w', encoding='utf-8') as f:
        f.write('<!-- Generated by scripts/gen_op_docs.py. DO NOT EDIT!!! -->\n')
        f.write('# Supported PyTorch Operators\n')
        for topic, ops, desc, eol in sections:
            print_operators(topic, ops, f, desc, eol)
# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()