# contrib/tflite2tf/tflite2tf.py
import os
import sys
import typing

from tinynn.converter.schemas.tflite import schema_generated as tflite
from tinynn.converter.utils.tflite import parse_model
## Utility functions
def input_name(op: tflite.Operator, idx: int, transform_tensors: typing.Set[int]):
    """Return the generated-code name for input *idx* of *op*.

    Tensors listed in *transform_tensors* are runtime values and appear as
    plain locals (``tensor_N``); every other tensor is a constant stored on
    the generated module object (``self.tensor_N``).
    """
    tid = op.Inputs(idx)
    prefix = '' if tid in transform_tensors else 'self.'
    return f'{prefix}tensor_{tid}'
def output_name(op: tflite.Operator, idx: int, transform_tensors: typing.Set[int]):
    """Return the local-variable name for output *idx* of *op*.

    Operator outputs are always produced at runtime, so the tensor id must
    already be registered in *transform_tensors*.
    """
    tid = op.Outputs(idx)
    assert tid in transform_tensors
    return f'tensor_{tid}'
def handle_fused_act(act, data_line):
    """Wrap *data_line* (a TF expression string) in the fused activation.

    NONE leaves the expression untouched; unsupported activations abort the
    whole conversion, matching the tool's fail-fast style.
    """
    wrappers = {
        tflite.ActivationFunctionType.NONE: None,
        tflite.ActivationFunctionType.RELU: 'tf.nn.relu',
        tflite.ActivationFunctionType.RELU6: 'tf.nn.relu6',
    }
    if act not in wrappers:
        print('Fused act not supported:', act)
        exit(1)
    fn = wrappers[act]
    return f'{fn}({data_line})' if fn else data_line
## OP parsing functions
def parse_pad(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.pad`` statement for a PAD operator."""
    dst = output_name(op, 0, transform_tensors)
    src = input_name(op, 0, transform_tensors)
    paddings = input_name(op, 1, transform_tensors)
    return f'{dst} = tf.pad({src}, {paddings})'
def parse_add(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.math.add`` statement for an ADD operator."""
    dst = output_name(op, 0, transform_tensors)
    lhs = input_name(op, 0, transform_tensors)
    rhs = input_name(op, 1, transform_tensors)
    return f'{dst} = tf.math.add({lhs}, {rhs})'
def parse_reshape(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.reshape`` statement for a RESHAPE operator."""
    dst = output_name(op, 0, transform_tensors)
    src = input_name(op, 0, transform_tensors)
    new_shape = input_name(op, 1, transform_tensors)
    return f'{dst} = tf.reshape({src}, {new_shape})'
def parse_resize(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.image.resize`` statement for a RESIZE_BILINEAR operator.

    NOTE(review): resize options (e.g. align_corners) are not forwarded;
    tf.image.resize defaults to bilinear, which presumably matches
    RESIZE_BILINEAR — confirm before reusing for other resize variants.
    """
    dst = output_name(op, 0, transform_tensors)
    src = input_name(op, 0, transform_tensors)
    size = input_name(op, 1, transform_tensors)
    return f'{dst} = tf.image.resize({src}, {size})'
def parse_relu(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.nn.relu`` statement for a RELU operator."""
    src = input_name(op, 0, transform_tensors)
    return f'{output_name(op, 0, transform_tensors)} = tf.nn.relu({src})'
def parse_depth2space(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.nn.depth_to_space`` statement for a DEPTH_TO_SPACE operator."""
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.DepthToSpaceOptions
    raw = op.BuiltinOptions()
    options = tflite.DepthToSpaceOptions()
    options.Init(raw.Bytes, raw.Pos)
    block_size = options.BlockSize()
    dst = output_name(op, 0, transform_tensors)
    src = input_name(op, 0, transform_tensors)
    return f'{dst} = tf.nn.depth_to_space({src}, {block_size})'
def parse_conv2d(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit ``tf.nn.conv2d`` + ``tf.nn.bias_add`` for a CONV_2D operator,
    wrapped in the fused activation if one is present."""
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.Conv2DOptions
    raw = op.BuiltinOptions()
    options = tflite.Conv2DOptions()
    options.Init(raw.Bytes, raw.Pos)
    stride = [options.StrideH(), options.StrideW()]
    pad_mode = "SAME" if options.Padding() == tflite.Padding.SAME else "VALID"
    dilation = [options.DilationHFactor(), options.DilationWFactor()]
    data = input_name(op, 0, transform_tensors)
    weight = input_name(op, 1, transform_tensors)
    bias = input_name(op, 2, transform_tensors)
    call = (
        f'tf.nn.bias_add(tf.nn.conv2d({data}, {weight}, strides={stride},'
        f' padding="{pad_mode}", dilations={dilation}), {bias})'
    )
    call = handle_fused_act(options.FusedActivationFunction(), call)
    return f'{output_name(op, 0, transform_tensors)} = {call}'
def parse_depthwiseconv2d(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit ``tf.nn.depthwise_conv2d`` + ``tf.nn.bias_add`` for a
    DEPTHWISE_CONV_2D operator, wrapped in the fused activation if any."""
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.DepthwiseConv2DOptions
    raw = op.BuiltinOptions()
    options = tflite.DepthwiseConv2DOptions()
    options.Init(raw.Bytes, raw.Pos)
    # depthwise_conv2d takes full NHWC strides, unlike conv2d.
    stride = [1, options.StrideH(), options.StrideW(), 1]
    pad_mode = "SAME" if options.Padding() == tflite.Padding.SAME else "VALID"
    dilation = [options.DilationHFactor(), options.DilationWFactor()]
    data = input_name(op, 0, transform_tensors)
    weight = input_name(op, 1, transform_tensors)
    bias = input_name(op, 2, transform_tensors)
    call = (
        f'tf.nn.bias_add(tf.nn.depthwise_conv2d({data}, {weight}, strides={stride},'
        f' padding="{pad_mode}", dilations={dilation}), {bias})'
    )
    call = handle_fused_act(options.FusedActivationFunction(), call)
    return f'{output_name(op, 0, transform_tensors)} = {call}'
def parse_transposeconv2d(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit ``tf.nn.conv2d_transpose`` + ``tf.nn.bias_add`` for a
    TRANSPOSE_CONV operator.

    TFLite input order here: 0 = output shape, 1 = weights, 2 = data,
    3 = bias — note the reordering in the emitted call.
    """
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.TransposeConvOptions
    raw = op.BuiltinOptions()
    options = tflite.TransposeConvOptions()
    options.Init(raw.Bytes, raw.Pos)
    stride = [options.StrideH(), options.StrideW()]
    pad_mode = "SAME" if options.Padding() == tflite.Padding.SAME else "VALID"
    out_shape = input_name(op, 0, transform_tensors)
    weight = input_name(op, 1, transform_tensors)
    data = input_name(op, 2, transform_tensors)
    bias = input_name(op, 3, transform_tensors)
    # TransposeConvOptions carries no dilation fields; emit the TF default.
    call = (
        f'tf.nn.bias_add(tf.nn.conv2d_transpose({data}, {weight}, output_shape={out_shape},'
        f' strides={stride}, padding="{pad_mode}", dilations=None), {bias})'
    )
    call = handle_fused_act(options.FusedActivationFunction(), call)
    return f'{output_name(op, 0, transform_tensors)} = {call}'
def parse_averagepool2d(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.nn.avg_pool2d`` statement for an AVERAGE_POOL_2D operator."""
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.Pool2DOptions
    raw = op.BuiltinOptions()
    options = tflite.Pool2DOptions()
    options.Init(raw.Bytes, raw.Pos)
    stride = [options.StrideH(), options.StrideW()]
    pad_mode = "SAME" if options.Padding() == tflite.Padding.SAME else "VALID"
    kernel = [options.FilterHeight(), options.FilterWidth()]
    dst = output_name(op, 0, transform_tensors)
    src = input_name(op, 0, transform_tensors)
    return f'{dst} = tf.nn.avg_pool2d({src}, strides={stride}, padding="{pad_mode}", ksize={kernel})'
def parse_fullyconnected(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit ``tf.matmul`` + ``tf.nn.bias_add`` for a FULLY_CONNECTED operator.

    NOTE(review): transpose_b=True suggests the TFLite weight is stored
    (out_features, in_features) — confirm against the schema.
    """
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.FullyConnectedOptions
    raw = op.BuiltinOptions()
    options = tflite.FullyConnectedOptions()
    options.Init(raw.Bytes, raw.Pos)
    data = input_name(op, 0, transform_tensors)
    weight = input_name(op, 1, transform_tensors)
    bias = input_name(op, 2, transform_tensors)
    call = f'tf.nn.bias_add(tf.matmul({data}, {weight}, transpose_b=True), {bias})'
    call = handle_fused_act(options.FusedActivationFunction(), call)
    return f'{output_name(op, 0, transform_tensors)} = {call}'
def parse_slice(op: tflite.Operator, transform_tensors: typing.Set[int]):
    """Emit a ``tf.slice`` statement for a SLICE operator.

    Inputs: 0 = data, 1 = begin indices, 2 = sizes.
    """
    # Sanity check only — SliceOptions has no fields we need, so the dead
    # SliceOptions()/Init() construction of the original was removed.
    assert op.BuiltinOptionsType() == tflite.BuiltinOptions.SliceOptions
    line = (
        f'{output_name(op, 0, transform_tensors)} = tf.slice({input_name(op, 0, transform_tensors)},'
        f' {input_name(op, 1, transform_tensors)}, {input_name(op, 2, transform_tensors)})'
    )
    return line
# OP parser registration table
# Maps a TFLite builtin operator name (as produced by OP_NAME_MAPPING) to
# the function that emits the equivalent TensorFlow statement.
# `parse_tflite` aborts on any operator whose name is missing here.
OP_PARSER_DICT = {
    'PAD': parse_pad,
    'ADD': parse_add,
    'RESIZE_BILINEAR': parse_resize,
    'RESHAPE': parse_reshape,
    'CONV_2D': parse_conv2d,
    'AVERAGE_POOL_2D': parse_averagepool2d,
    'DEPTHWISE_CONV_2D': parse_depthwiseconv2d,
    'FULLY_CONNECTED': parse_fullyconnected,
    'DEPTH_TO_SPACE': parse_depth2space,
    'TRANSPOSE_CONV': parse_transposeconv2d,
    'SLICE': parse_slice,
    'RELU': parse_relu,
}
# Header for the generated script
# The single `{}` placeholder receives the absolute path of the source
# TFLite model, so the generated script can re-read its weight buffers.
HEADER = """from tinynn.converter.utils.tflite import parse_model
import tensorflow as tf
import numpy as np
tfl_model = parse_model("{}")
"""
# Conversion logic for the generated script.
# The four `str.format` placeholders are filled, in order, with:
#   1. comma-separated wrapper parameter names,
#   2. the same names forwarded to the TFModel call,
#   3. the dict expression mapping "output_i" names to results,
#   4. the `input_i=tf.TensorSpec(...)` keyword arguments of the signature.
# BUGFIX: the template's body indentation had been lost, which made the
# emitted `wrapper` function syntactically invalid; restored here.
CONVERT_LOGIC = """
tf_model = TFModel()


@tf.function
def wrapper({}):
    result = tf_model({})
    outputs = {}
    return outputs


tf.saved_model.save(tf_model,
                    "saved_model",
                    signatures=wrapper.get_concrete_function(
                        {}
                    ),
)
"""
## Constants for TF and TFLite mapping
# mapping between op code and op name
# Built by reflecting over `tflite.BuiltinOperator`; dunder attributes are
# filtered out so only the generated operator constants remain.
OP_NAME_MAPPING = {
    getattr(tflite.BuiltinOperator, k): k
    for k in dir(tflite.BuiltinOperator)
    if not k.startswith('__') and not k.endswith('__')
}

# TFLite tensor type -> numpy dtype string (used when materializing
# constant buffers in the generated script).
TFLITE_NP_TYPE_MAPPING = {
    tflite.TensorType.FLOAT32: "float32",
    tflite.TensorType.INT32: "int32",
    tflite.TensorType.BOOL: "bool",
    tflite.TensorType.FLOAT64: "float64",
    tflite.TensorType.INT64: "int64",
}

# TFLite tensor type -> TF dtype expression string (used in the generated
# input signatures). Unlisted types abort the conversion.
TFLITE_TF_TYPE_MAPPING = {
    tflite.TensorType.FLOAT32: "tf.float32",
    tflite.TensorType.INT32: "tf.int32",
    tflite.TensorType.BOOL: "tf.bool",
    tflite.TensorType.FLOAT64: "tf.float64",
    tflite.TensorType.INT64: "tf.int64",
}
## Main logic
def parse_tflite(path):
    """Translate a single-subgraph TFLite model into a standalone script,
    ``generate_tf_savedmodel.py``, which rebuilds the graph with TF ops and
    exports it as a SavedModel.

    Args:
        path (str): filesystem path of the ``.tflite`` model.

    Raises:
        TypeError: if *path* is not a ``str``.
        AssertionError: if the model contains more than one subgraph.
    """
    if not isinstance(path, str):
        # Raise instead of `assert False` so the check survives `python -O`.
        raise TypeError(f"expected type str but got {type(path)}")
    model = parse_model(path)
    assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
    subgraph = model.Subgraphs(0)
    input_names = []
    input_shapes = []
    input_signatures = []
    # Context manager guarantees the generated file is closed (and flushed)
    # even if code generation fails halfway through.
    with open('generate_tf_savedmodel.py', 'w', encoding='utf-8') as output_file:
        output_file.write(HEADER.format(os.path.abspath(path)))
        # Collect input info so that we can generate input signatures
        for i in range(subgraph.InputsLength()):
            inp = subgraph.Inputs(i)
            tensor = subgraph.Tensors(inp)
            dtype = TFLITE_TF_TYPE_MAPPING.get(tensor.Type(), None)
            if dtype is None:
                print('Dtype not supported:', tensor.Type())
                exit(1)
            input_names.append(f'tensor_{inp}')
            input_shapes.append(tuple(tensor.ShapeAsNumpy().tolist()))
            input_signatures.append(f"tf.TensorSpec(shape={input_shapes[-1]}, dtype={dtype})")
        line = '''class TFModel(tf.Module):
    def __init__(self):'''
        output_file.write(f'{line}\n')
        # For some ops, the layout of the weight data is different. Mark the
        # weight tensor ids here: 1 = conv/depthwise transpose, 2 = transpose-conv.
        buffer_transform_dict = {}
        for i in range(subgraph.OperatorsLength()):
            op = subgraph.Operators(i)
            opcode = model.OperatorCodes(op.OpcodeIndex())
            if opcode.BuiltinCode() in (tflite.BuiltinOperator.CONV_2D, tflite.BuiltinOperator.DEPTHWISE_CONV_2D):
                buffer_transform_dict[op.Inputs(1)] = 1
            elif opcode.BuiltinCode() == tflite.BuiltinOperator.TRANSPOSE_CONV:
                buffer_transform_dict[op.Inputs(1)] = 2
        # Generate buffer definitions in TF. Tensors that reference buffer 0
        # carry no constant data: they are runtime values and become locals.
        transform_tensors = set()
        for i in range(subgraph.TensorsLength()):
            tensor = subgraph.Tensors(i)
            if tensor.Buffer() != 0:
                buffer = tensor.Buffer()
                dtype = TFLITE_NP_TYPE_MAPPING.get(tensor.Type(), None)
                if dtype is None:
                    print('Dtype not supported:', tensor.Type())
                    exit(1)
                shape = tensor.ShapeAsNumpy()
                # For marked weights, certain weight transformation is performed
                transform = buffer_transform_dict.get(i, None)
                data_line = (
                    f'np.frombuffer(tfl_model.Buffers({buffer}).DataAsNumpy().tobytes(),'
                    f' dtype="{dtype}").reshape({shape.tolist()})'
                )
                if transform in (1, 2):
                    if transform == 1:
                        sequence = (1, 2, 3, 0)
                    else:
                        sequence = (1, 2, 0, 3)
                    data_line = f'np.transpose({data_line}, {sequence})'
                elif transform is not None:
                    print(f'Unknown transform: {transform}')
                    exit(1)
                line = f'self.tensor_{i} = tf.constant({data_line})'
                # Eight spaces: body of TFModel.__init__ in the generated file.
                output_file.write(f'        {line}\n')
            else:
                transform_tensors.add(i)
        output_file.write('\n')
        line = f'''
    @tf.function
    def __call__(self, {", ".join(input_names)}):'''
        output_file.write(f'{line}\n')
        # Parsing operations and generating function calls in TF
        for i in range(subgraph.OperatorsLength()):
            op = subgraph.Operators(i)
            opcode = model.OperatorCodes(op.OpcodeIndex())
            code = opcode.BuiltinCode()
            if OP_NAME_MAPPING[code] in OP_PARSER_DICT:
                parser = OP_PARSER_DICT[OP_NAME_MAPPING[code]]
                line = parser(op, transform_tensors)
                output_file.write(f'        {line}\n')
            else:
                print(f'OpCode Unknown: {code}({OP_NAME_MAPPING[code]})')
                exit(1)
        # Generating return values according to model outputs
        output_names = []
        for i in range(subgraph.OutputsLength()):
            outp = subgraph.Outputs(i)
            output_names.append(f'tensor_{outp}')
        input_names_for_conversion = ", ".join(f'input_{i}' for i in range(len(input_names)))
        if len(output_names) == 1:
            output_for_conversion = '{"output_0": result}'
        else:
            # BUGFIX: the index must be interpolated; the old f-string emitted
            # the literal text `result[i]` for every output.
            outputs = [f'"output_{i}": result[{i}]' for i in range(len(output_names))]
            output_for_conversion = f'{{ {", ".join(outputs)} }}'
        input_signatures_for_conversion = ', '.join(f'input_{i}={val}' for i, val in enumerate(input_signatures))
        line = f'return {", ".join(output_names)}'
        output_file.write(f'        {line}\n')
        conversion_logic = CONVERT_LOGIC.format(
            input_names_for_conversion, input_names_for_conversion, output_for_conversion, input_signatures_for_conversion
        )
        output_file.write(f'\n{conversion_logic}\n')
if __name__ == '__main__':
    # Allow the model path to be supplied on the command line; the original
    # hard-coded example path remains the default for backward compatibility.
    _DEFAULT_MODEL = '/workspaces/TinyNeuralNetwork/examples/converter/out/mbv1_224.tflite'
    parse_tflite(sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_MODEL)