def parse_tflite()

in contrib/tflite2tf/tflite2tf.py [0:0]


def parse_tflite(path, output_path='generate_tf_savedmodel.py'):
    """Translate a single-subgraph TFLite model into a TF SavedModel generator script.

    The emitted script contains, in order:
      * ``HEADER`` formatted with the absolute path of the model;
      * a ``TFModel(tf.Module)`` class whose ``__init__`` turns every tensor that
        has a non-zero buffer index into a ``tf.constant`` (conv / depthwise-conv /
        transpose-conv weights are axis-permuted via ``np.transpose``), and whose
        ``@tf.function __call__`` emits one line per operator using the matching
        entry of ``OP_PARSER_DICT`` and returns the model outputs;
      * ``CONVERT_LOGIC`` formatted with the input names, the output dict and the
        input ``tf.TensorSpec`` signatures.

    The generated constants read ``tfl_model.Buffers(...)``, so the script is
    expected to define ``tfl_model`` (presumably in ``HEADER`` — verify there).

    Args:
        path: Filesystem path of the ``.tflite`` model; must be a ``str``.
        output_path: Destination of the generated script. Defaults to
            ``'generate_tf_savedmodel.py'`` (relative to the working directory),
            which was previously the only option.

    Raises:
        AssertionError: if ``path`` is not a ``str`` or the model does not have
            exactly one subgraph. Raised explicitly (not via ``assert``) so the
            checks still run under ``python -O``.
        SystemExit: with status 1, after printing a message, when a tensor dtype
            or an operator is not supported.

    The script is buffered in memory and written only after the whole model has
    been translated, so a failure never leaves a truncated file at ``output_path``.
    """
    import sys

    def fail(*parts):
        # Report an unsupported construct the same way the CLI always has, then abort.
        print(*parts)
        sys.exit(1)

    def tensor_shape(tensor):
        # Element-wise accessors return [] when the shape vector is absent, whereas
        # ShapeAsNumpy() returns the int 0 in that case and would break .tolist().
        return [tensor.Shape(j) for j in range(tensor.ShapeLength())]

    if not isinstance(path, str):
        raise AssertionError(f"expected type str but got {type(path)}")
    model = parse_model(path)

    if model.SubgraphsLength() != 1:
        raise AssertionError("Only one subgraph is supported")
    subgraph = model.Subgraphs(0)

    lines = [HEADER.format(os.path.abspath(path))]

    # Model inputs become the arguments of __call__ and the conversion signatures.
    input_names = []
    input_signatures = []
    for i in range(subgraph.InputsLength()):
        inp = subgraph.Inputs(i)
        tensor = subgraph.Tensors(inp)

        dtype = TFLITE_TF_TYPE_MAPPING.get(tensor.Type())
        if dtype is None:
            fail('Dtype not supported:', tensor.Type())

        input_names.append(f'tensor_{inp}')
        shape = tuple(tensor_shape(tensor))
        input_signatures.append(f"tf.TensorSpec(shape={shape}, dtype={dtype})")

    lines.append('class TFModel(tf.Module):\n    def __init__(self):\n')

    # TFLite stores some weights in a layout that differs from TF's. Map the weight
    # tensor (input #1 of the op) to the axis permutation applied when emitting it.
    conv_codes = (tflite.BuiltinOperator.CONV_2D, tflite.BuiltinOperator.DEPTHWISE_CONV_2D)
    weight_permutations = {}
    for i in range(subgraph.OperatorsLength()):
        op = subgraph.Operators(i)
        code = model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        if code in conv_codes:
            weight_permutations[op.Inputs(1)] = (1, 2, 3, 0)
        elif code == tflite.BuiltinOperator.TRANSPOSE_CONV:
            weight_permutations[op.Inputs(1)] = (1, 2, 0, 3)

    # Tensors with a non-zero buffer index become constants; the rest (buffer 0,
    # the empty sentinel buffer by TFLite convention) are produced at runtime and
    # are handed to the op parsers as `transform_tensors`.
    transform_tensors = set()
    for i in range(subgraph.TensorsLength()):
        tensor = subgraph.Tensors(i)
        buffer = tensor.Buffer()
        if buffer == 0:
            transform_tensors.add(i)
            continue

        dtype = TFLITE_NP_TYPE_MAPPING.get(tensor.Type())
        if dtype is None:
            fail('Dtype not supported:', tensor.Type())

        data_line = (
            f'np.frombuffer(tfl_model.Buffers({buffer}).DataAsNumpy().tobytes(),'
            f' dtype="{dtype}").reshape({tensor_shape(tensor)})'
        )
        permutation = weight_permutations.get(i)
        if permutation is not None:
            data_line = f'np.transpose({data_line}, {permutation})'
        lines.append(f'        self.tensor_{i} = tf.constant({data_line})\n')

    lines.append('\n')
    lines.append(f'\n    @tf.function\n    def __call__(self, {", ".join(input_names)}):\n')

    # One generated statement per operator, produced by the op-specific parser.
    for i in range(subgraph.OperatorsLength()):
        op = subgraph.Operators(i)
        code = model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        # .get() so an opcode missing from the name table reports "OpCode Unknown"
        # instead of escaping as a bare KeyError.
        op_name = OP_NAME_MAPPING.get(code)
        if op_name not in OP_PARSER_DICT:
            fail(f'OpCode Unknown: {code}({op_name})')
        parser = OP_PARSER_DICT[op_name]
        lines.append(f'        {parser(op, transform_tensors)}\n')

    output_names = [f'tensor_{subgraph.Outputs(i)}' for i in range(subgraph.OutputsLength())]
    lines.append(f'        return {", ".join(output_names)}\n')

    input_names_for_conversion = ", ".join(f'input_{i}' for i in range(len(input_names)))
    if len(output_names) == 1:
        output_for_conversion = '{"output_0": result}'
    else:
        # Each output key must index its own element of the returned tuple.
        outputs = ", ".join(f'"output_{i}": result[{i}]' for i in range(len(output_names)))
        output_for_conversion = f'{{ {outputs} }}'
    input_signatures_for_conversion = ', '.join(
        f'input_{i}={signature}' for i, signature in enumerate(input_signatures)
    )
    conversion_logic = CONVERT_LOGIC.format(
        input_names_for_conversion,
        input_names_for_conversion,
        output_for_conversion,
        input_signatures_for_conversion,
    )
    lines.append(f'\n{conversion_logic}\n')

    with open(output_path, 'w', encoding='utf-8') as output_file:
        output_file.write(''.join(lines))