import os

from tinynn.graph.quantization.quantizer import (
    FUNCTIONAL_MODULE_MAPPING,
    FUSE_FALLBACK_DICT,
    FUSE_RULE_LIST,
    FUSE_RULE_LIST_EXTRA,
    FUSE_RULE_LIST_PTQ_ONLY,
    KNOWN_QSTATS,
    Q_MODULES_MAPPING,
    REWRITE_QUANTIZABLE_RULE_LIST,
    REWRITE_TO_FUSE_RULE_LIST,
    UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST,
)
from tinynn.graph.tracer import qualified_name

# Absolute path of the directory containing this script. main() takes its
# parent directory as the root under which the `docs` folder is located
# (the generated header names the script as scripts/gen_quantized_docs.py).
CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))


def prepare_unsupported_operators(lines):
    """Append the "unsupported operators" markdown section to ``lines``.

    The table contains every op of ``UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST``
    plus every key of ``Q_MODULES_MAPPING``. Keys taken from
    ``Q_MODULES_MAPPING`` replace any version recorded for the same op with
    ``None``. A falsy version is rendered as ``/``. Rows are keyed by the short
    qualified name of the op and sorted alphabetically.

    Args:
        lines: list of markdown lines; extended in place.
    """
    lines.extend(
        [
            '## Unsupported operators in PyTorch for static quantization\n',
            'Quantized OPs that are natively not supported by PyTorch (and possibly TFLite). But some of them can be'
            ' translated to quantized TFLite through extra configuration.\n',
            '| Operator                  | Minimum Supported PyTorch Version  |\n',
            '|---------------------------|------------------------------------|\n',
        ]
    )

    # Module-mapped ops are listed too, but without a minimum version.
    version_by_op = {**UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST, **dict.fromkeys(Q_MODULES_MAPPING)}

    rows = {}
    for op, version in version_by_op.items():
        rows[qualified_name(op, short=True)] = version if version else '/'

    for name in sorted(rows):
        lines.append(f'| `{name}` | {rows[name]} |\n')


def prepare_rewrite_quantizable_operators(lines):
    """Append the table of extra flags needed to emit quantized TFLite ops.

    Candidate entries are tuples of one or more ops, gathered from:
      * every op in ``KNOWN_QSTATS``,
      * every rule in ``REWRITE_QUANTIZABLE_RULE_LIST``,
      * every key of ``Q_MODULES_MAPPING``,
      * ops present in both ``UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST`` and
        ``FUNCTIONAL_MODULE_MAPPING``.

    Each entry gets a notes cell listing the quantizer/converter options it
    requires, or "No action needed" when none apply. Multi-op entries are
    rendered as ``{op1, op2, ...}``. Rows are sorted by their rendered name.

    Args:
        lines: list of markdown lines; extended in place.
    """
    lines.extend(
        [
            '## Extra flags for translating the above ops to quantized TFLite\n',
            '| Operators                  | Notes  |\n',
            '|----------------------------|--------|\n',
        ]
    )

    shared_functional = UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.keys() & FUNCTIONAL_MODULE_MAPPING.keys()
    candidates = (
        {(op,) for op in KNOWN_QSTATS}
        | REWRITE_QUANTIZABLE_RULE_LIST
        | {(op,) for op in Q_MODULES_MAPPING}
        | {(op,) for op in shared_functional}
    )

    rendered = {}
    for entry in candidates:
        is_single = len(entry) == 1
        notes = []
        if is_single and entry[0] in KNOWN_QSTATS:
            notes.append('For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`')
        # Ops with a module/functional replacement need no converter rewrite; everything else does.
        if not is_single or not (entry[0] in Q_MODULES_MAPPING or entry[0] in FUNCTIONAL_MODULE_MAPPING):
            notes.append('For TFLiteConverter, set `rewrite_quantizable=True`')

        if is_single:
            name = qualified_name(entry[0], short=True)
        else:
            name = '{' + ', '.join(qualified_name(op, short=True) for op in entry) + '}'
        rendered[name] = '<br>'.join(notes) if notes else 'No action needed'

    for name in sorted(rendered):
        lines.append(f'| `{name}` | {rendered[name]} |\n')


def prepare_fusion_rules(lines):
    """Append the table of supported fusion rules to ``lines``.

    Every rule from ``FUSE_RULE_LIST``, ``FUSE_RULE_LIST_EXTRA`` and
    ``REWRITE_TO_FUSE_RULE_LIST`` is listed with an empty note. Rules in
    ``FUSE_RULE_LIST_PTQ_ONLY`` are annotated according to two things:
    whether they also appear in those general lists, and whether a minimum
    PyTorch version is recorded for them. Rows are rendered as
    ``{op1, op2, ...}`` using short qualified names (remapped through
    ``FUSE_FALLBACK_DICT``) and sorted alphabetically.

    Args:
        lines: list of markdown lines; extended in place.

    Raises:
        AssertionError: if a rule is in ``FUSE_RULE_LIST_PTQ_ONLY`` without a
            version but is also present in one of the general fusion lists.
    """
    lines.append('## Supported fusion rules for static quantization\n')

    lines.append('| Operators                  | Notes  |\n')
    lines.append('|----------------------------|--------|\n')

    # Rules from the general lists start with no note.
    notes_by_rule = dict.fromkeys(set((*FUSE_RULE_LIST, *FUSE_RULE_LIST_EXTRA, *REWRITE_TO_FUSE_RULE_LIST)))

    for rule, min_version in FUSE_RULE_LIST_PTQ_ONLY.items():
        in_general_list = rule in notes_by_rule
        if min_version is None:
            if in_general_list:
                # Explicit raise instead of `assert False`: asserts are stripped under
                # `python -O`, which would silently emit this rule with an empty note.
                raise AssertionError(
                    f'Should not happen: fusion rule {rule} is PTQ-only without a version '
                    'but is also listed as a general fusion rule'
                )
            notes_by_rule[rule] = 'PTQ only.'
        elif in_general_list:
            notes_by_rule[rule] = f'for PTQ, only PyTorch {min_version}+ is supported'
        else:
            notes_by_rule[rule] = f'PTQ only. Only PyTorch {min_version}+ is supported'

    rows = {}
    for rule, note in notes_by_rule.items():
        names = [qualified_name(op, short=True) for op in rule]
        display = ', '.join(FUSE_FALLBACK_DICT.get(name, name) for name in names)
        rows[display] = '' if note is None else note

    for display in sorted(rows):
        lines.append(f'| `{{{display}}}` | {rows[display]} |\n')


def prepare_quantization_support(lines):
    """Append the complete quantization-support document to ``lines``.

    Writes the top-level title, then the unsupported-operator, extra-flag and
    fusion-rule sections, in that order.

    Args:
        lines: list of markdown lines; extended in place.
    """
    lines.append('# TinyNN Quantization Support\n')

    for emit_section in (
        prepare_unsupported_operators,
        prepare_rewrite_quantizable_operators,
        prepare_fusion_rules,
    ):
        emit_section(lines)


def main():
    """Regenerate ``docs/quantization_support.md`` from the quantizer tables.

    The output file lives at ``<parent of CURRENT_PATH>/docs/quantization_support.md``
    and is overwritten on every run.
    """
    root_dir = os.path.dirname(CURRENT_PATH)
    file_path = os.path.join(root_dir, 'docs', 'quantization_support.md')

    lines = ['<!-- Generated by scripts/gen_quantized_docs.py. DO NOT EDIT!!! -->\n']

    prepare_quantization_support(lines)

    # Force LF line endings so the generated document is byte-identical regardless
    # of the platform that regenerates it. In text mode with the default newline=None,
    # every '\n' is translated to os.linesep (CRLF on Windows).
    with open(file_path, 'w', encoding='utf-8', newline='\n') as f:
        f.writelines(lines)


if __name__ == '__main__':
    main()
