scripts/gen_op_docs.py (84 lines of code) (raw):

"""Generate ``docs/op_matrix.md``, the table of PyTorch operators supported by the converter.

Every entry of ``OPERATOR_CONVERTER_DICT`` is placed in a documentation section according to its
converter base class or, for regular converters, its defining module. For regular converters, the
string-literal messages of the ``assert`` statements at the top level of ``parse`` are listed as the
operator's limitations.
"""

import ast
import inspect
import os
import re
import textwrap

from tinynn.converter.operators.torch import OPERATOR_CONVERTER_DICT
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter, TrackConstantOperator

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

# Operator names per documentation section; filled and sorted in place by collect_ops().
prim_ops = []
aten_ops = []
quantized_ops = []
torchvision_ops = []
passthrough_ops = []
track_constant_ops = []
# Operator name -> '<br>'-joined assertion messages taken from its converter's parse().
limitation_dict = {}


def main():
    """Collect the registered operators and regenerate ``docs/op_matrix.md``."""
    collect_ops()
    update_file()


def _assertion_messages(func):
    """Return the string-literal messages of the ``assert`` statements directly in *func*'s body.

    Only statements at the top level of the function body are inspected; asserts nested inside
    ``if``/``for``/``with`` blocks are ignored. Asserts without a message, or whose message is not a
    plain string literal (e.g. an f-string or a variable), are skipped.
    """
    # getsource() keeps a method's class-level indentation, which ast.parse() rejects; dedent first.
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    func_def = tree.body[0] if tree.body else None
    if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return []

    messages = []
    for stmt in func_def.body:
        if not isinstance(stmt, ast.Assert) or stmt.msg is None:
            continue
        try:
            # literal_eval accepts string nodes on every Python version (ast.Str or ast.Constant).
            msg = ast.literal_eval(stmt.msg)
        except ValueError:
            # Not a literal (f-string, variable, formatting expression): nothing static to document.
            continue
        if isinstance(msg, str):
            messages.append(msg)
    return messages


def collect_ops():
    """Classify every registered converter and record limitations of the regular ones.

    Appends operator names to the module-level section lists, stores limitations in
    ``limitation_dict``, then sorts every section list in place.

    Raises:
        RuntimeError: if a converter that is not a prim/no-track/track-constant operator is
            defined in a module other than the aten, quantized and torchvision modules.
    """
    module_ops = {
        'tinynn.converter.operators.torch.aten': aten_ops,
        'tinynn.converter.operators.torch.quantized': quantized_ops,
        'tinynn.converter.operators.torch.torchvision': torchvision_ops,
    }

    for name, converter in OPERATOR_CONVERTER_DICT.items():
        # The base-class checks come first, so they win over the defining module.
        if issubclass(converter, PrimOperatorConverter):
            prim_ops.append(name)
        elif issubclass(converter, NoTrackOperator):
            passthrough_ops.append(name)
        elif issubclass(converter, TrackConstantOperator):
            track_constant_ops.append(name)
        else:
            ops = module_ops.get(converter.__module__)
            if ops is None:
                # A real exception (not `assert`) so the check is not stripped under `python -O`.
                raise RuntimeError(f"Unknown op: {name}, {converter}")
            ops.append(name)

            messages = _assertion_messages(converter.parse)
            if messages:
                limitation_dict[name] = '<br>'.join(messages)

    for ops in (prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, track_constant_ops):
        ops.sort()


def _format_cell(text):
    """Escape *text* so it stays inside a single Markdown table cell."""
    # A newline would end the table row and a bare '|' would start a new cell.
    return re.sub(r'\s*\n\s*', '<br>', text).replace('|', r'\|')


def print_operators(topic, ops, f, desc=None, eol=True):
    """Write a Markdown section listing *ops* and their limitations to the text stream *f*.

    Args:
        topic: Section title, written as a level-2 heading.
        ops: Operator names, one table row each. Nothing is written when it is empty.
        f: Writable text stream.
        desc: Optional line written below the heading.
        eol: Whether to write a blank line after the table.
    """
    if not ops:
        return

    f.write(f'## {topic}\n')
    if desc is not None:
        f.write(f'{desc}\n')
    f.write('| Operator | Limitations |\n')
    f.write('|---------------------------|--------------|\n')
    for k in ops:
        limitation = _format_cell(limitation_dict.get(k, ''))
        f.write(f'| `{k}` | {limitation} |\n')
    if eol:
        f.write('\n')


def update_file():
    """Overwrite ``docs/op_matrix.md`` under the parent of this script's directory."""
    root_dir = os.path.dirname(CURRENT_PATH)
    file_path = os.path.join(root_dir, 'docs', 'op_matrix.md')
    with open(file_path, 'w', encoding='utf-8') as f:
        file_header = '<!-- Generated by scripts/gen_op_docs.py. DO NOT EDIT!!! -->\n'
        f.write(file_header)
        header = '# Supported PyTorch Operators\n'
        f.write(header)
        print_operators('Primitives', prim_ops, f, 'Operators that are implemented in Python')
        print_operators('ATen Operators', aten_ops, f)
        print_operators('Quantized Operators', quantized_ops, f)
        print_operators('TorchVision Operators', torchvision_ops, f)
        print_operators(
            'Passthrough Operators',
            passthrough_ops,
            f,
            'Non-tracking operators that are ignored during translation',
        )
        # Last section: no trailing blank line after its table.
        print_operators(
            'Constant Tracking Operators',
            track_constant_ops,
            f,
            'Tracking operators that produce a dynamic constant',
            eol=False,
        )


if __name__ == '__main__':
    main()