import inspect
import io
import os
import re
import tokenize

from tinynn.converter.operators.torch import OPERATOR_CONVERTER_DICT
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter, TrackConstantOperator

# Absolute path of the directory that holds this script.
CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

# Operator names grouped by category; populated (and sorted) by collect_ops().
(
    prim_ops,
    aten_ops,
    quantized_ops,
    torchvision_ops,
    passthrough_ops,
    track_constant_ops,
) = ([], [], [], [], [], [])

# Operator name -> '<br>'-joined assert messages taken from its parse() method.
limitation_dict = {}


def main():
    """Script entry point: gather converter metadata, then rewrite the docs page."""
    for step in (collect_ops, update_file):
        step()


# Raw text of a string literal at the start of the matched text: an optional
# prefix (f, r, b, ...), then matching quotes. Escaped characters are skipped
# over (and kept verbatim) so an escaped quote does not end the literal.
_STRING_LITERAL_RE = re.compile(
    r"""[rRuUfFbBtT]{0,2}(?P<quote>'''|\"\"\"|'|")(?P<msg>(?:\\.|(?!(?P=quote)).)*)(?P=quote)"""
)

# Token types that can start a string literal; FSTRING_START / TSTRING_START
# only exist on newer Python versions, hence the hasattr filter.
_STRING_TOKEN_TYPES = frozenset(
    getattr(tokenize, name) for name in ('STRING', 'FSTRING_START', 'TSTRING_START') if hasattr(tokenize, name)
)


def extract_limitations(source, indent=8):
    """Return the messages of the ``assert <cond>, <msg>`` statements in *source*.

    Only asserts whose line starts with exactly *indent* spaces are considered
    (8 = the top level of a method body inside a class). The message is the raw
    text between the quotes of the string literal that directly follows the
    assert's top-level comma (optionally wrapped in parentheses); string
    prefixes such as ``f`` are dropped and escape sequences are kept verbatim.
    Asserts without such a literal message contribute nothing.

    The source is tokenized rather than regex-scanned so that commas or strings
    inside the condition, and asserts without a message, cannot be mistaken
    for the message of some later statement.

    Args:
        source: Python source text, e.g. from ``inspect.getsource``.
        indent: Exact number of leading spaces an assert line must have.

    Returns:
        The extracted messages, in source order.
    """
    lines = source.split('\n')
    prefix = ' ' * indent
    messages = []
    state = None  # None: outside an assert; 'cond': in its condition; 'msg': just after its comma
    depth = 0  # bracket nesting level inside the condition

    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if state is None:
            if (
                tok.type == tokenize.NAME
                and tok.string == 'assert'
                and tok.start[1] == indent
                and lines[tok.start[0] - 1].startswith(prefix)
            ):
                state, depth = 'cond', 0
        elif tok.type == tokenize.NEWLINE:
            # End of the logical line without a recognised message.
            state = None
        elif tok.type in (tokenize.NL, tokenize.COMMENT):
            continue
        elif state == 'cond':
            if tok.type == tokenize.OP:
                if tok.string in ('(', '[', '{'):
                    depth += 1
                elif tok.string in (')', ']', '}'):
                    depth -= 1
                elif tok.string == ',' and depth == 0:
                    # Only the separator between condition and message sits at depth 0.
                    state = 'msg'
        elif tok.type == tokenize.OP and tok.string == '(':
            continue  # parenthesized message: inspect what is inside
        else:
            if tok.type in _STRING_TOKEN_TYPES:
                match = _STRING_LITERAL_RE.match(lines[tok.start[0] - 1], tok.start[1])
                if match is not None:
                    messages.append(match.group('msg'))
            state = None

    return messages


def collect_ops():
    """Sort every registered converter into a category and gather its limitations.

    Each key of ``OPERATOR_CONVERTER_DICT`` is appended to exactly one of the
    module-level operator lists, chosen by the converter's base class or, for
    regular converters, by the module that defines it. For those regular
    converters, the top-level assert messages of ``parse`` are joined with
    ``<br>`` and stored in ``limitation_dict``. All lists are sorted at the end.

    Raises:
        AssertionError: if a regular converter is defined in an unexpected
            module (raised explicitly, so it also fires under ``python -O``).
    """
    global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, track_constant_ops, limitation_dict

    for k, v in OPERATOR_CONVERTER_DICT.items():
        if issubclass(v, PrimOperatorConverter):
            prim_ops.append(k)
        elif issubclass(v, NoTrackOperator):
            passthrough_ops.append(k)
        elif issubclass(v, TrackConstantOperator):
            track_constant_ops.append(k)
        else:
            if v.__module__ == 'tinynn.converter.operators.torch.aten':
                aten_ops.append(k)
            elif v.__module__ == 'tinynn.converter.operators.torch.quantized':
                quantized_ops.append(k)
            elif v.__module__ == 'tinynn.converter.operators.torch.torchvision':
                torchvision_ops.append(k)
            else:
                # Not `assert False`: that is stripped under -O and the op would
                # silently disappear from the generated docs.
                raise AssertionError(f"Unknown op: {k}, {v}")

            messages = extract_limitations(inspect.getsource(v.parse))
            if messages:
                limitation_dict[k] = '<br>'.join(messages)

    prim_ops = sorted(prim_ops)
    aten_ops = sorted(aten_ops)
    quantized_ops = sorted(quantized_ops)
    torchvision_ops = sorted(torchvision_ops)
    passthrough_ops = sorted(passthrough_ops)
    track_constant_ops = sorted(track_constant_ops)


def print_operators(topic, ops, f, desc=None, eol=True, limitations=None):
    """Write one Markdown section with a table of *ops* and their limitations.

    Nothing is written when *ops* is empty.

    Args:
        topic: Section title, emitted as a ``##`` heading.
        ops: Operator names, one table row each, in the given order.
        f: Writable text stream.
        desc: Optional line of text placed under the heading.
        eol: Whether to emit a blank line after the table.
        limitations: Mapping of operator name -> limitation text; defaults to
            the module-level ``limitation_dict``.
    """
    if not ops:
        return
    if limitations is None:
        limitations = limitation_dict

    f.write(f'## {topic}\n')
    if desc is not None:
        f.write(f'{desc}\n')
    f.write('| Operator                  | Limitations  |\n')
    f.write('|---------------------------|--------------|\n')

    for k in ops:
        # A bare '|' would be parsed as a column separator and break the row.
        limitation = limitations.get(k, '').replace('|', '\\|')
        f.write(f'| `{k}` | {limitation} |\n')

    if eol:
        f.write('\n')


def update_file():
    """Regenerate ``op_matrix.md`` from the module-level operator lists.

    The target is ``docs/op_matrix.md`` under the parent of this script's
    directory and is overwritten completely. It is written with LF line
    endings on every platform, so regenerating it on Windows does not turn
    every line into a CRLF diff.
    """
    root_dir = os.path.dirname(CURRENT_PATH)
    file_path = os.path.join(root_dir, 'docs', 'op_matrix.md')

    # newline='\n' disables the platform-specific translation of '\n' on write.
    with open(file_path, 'w', encoding='utf-8', newline='\n') as f:

        file_header = '<!-- Generated by scripts/gen_op_docs.py. DO NOT EDIT!!! -->\n'
        f.write(file_header)

        header = '# Supported PyTorch Operators\n'
        f.write(header)

        print_operators('Primitives', prim_ops, f, 'Operators that are implemented in Python')
        print_operators('ATen Operators', aten_ops, f)
        print_operators('Quantized Operators', quantized_ops, f)
        print_operators('TorchVision Operators', torchvision_ops, f)
        print_operators(
            'Passthrough Operators',
            passthrough_ops,
            f,
            'Non-tracking operators that are ignored during translation',
        )
        # Last section: no trailing blank line at the end of the file.
        print_operators(
            'Constant Tracking Operators',
            track_constant_ops,
            f,
            'Tracking operators that produce a dynamic constant',
            False,
        )


if __name__ == '__main__':
    # Only regenerate the docs when run as a script, not when imported.
    main()
