in contrib/tflite2tf/tflite2tf.py [0:0]
def parse_tflite(path):
    assert isinstance(path, str), f"expected type str but got {type(path)}"
    model = parse_model(path)
    assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
    subgraph = model.Subgraphs(0)
    input_names = []
    input_shapes = []
    input_signatures = []
    output_file = open('generate_tf_savedmodel.py', 'w', encoding='utf-8')
    output_file.write(HEADER.format(os.path.abspath(path)))
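    # HEADER presumably holds the generated script's imports and re-parses the
    # original .tflite file (by absolute path) as `tfl_model`, which the emitted
    # tf.constant(...) lines below read their raw buffer bytes from.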
    # Collect input info so that we can generate input signatures
    for i in range(subgraph.InputsLength()):
        inp = subgraph.Inputs(i)
        tensor = subgraph.Tensors(inp)
        dtype = TFLITE_TF_TYPE_MAPPING.get(tensor.Type(), None)
        if dtype is None:
            print('Dtype not supported:', tensor.Type())
            exit(1)
        input_names.append(f'tensor_{inp}')
        input_shapes.append(tuple(tensor.ShapeAsNumpy().tolist()))
        input_signatures.append(f"tf.TensorSpec(shape={input_shapes[-1]}, dtype={dtype})")
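    # Each entry becomes source text in the generated script, e.g. something like
    # "tf.TensorSpec(shape=(1, 224, 224, 3), dtype=tf.float32)" for a float NHWC
    # image input (illustrative shape and dtype).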
    line = '''class TFModel(tf.Module):
    def __init__(self):'''
    output_file.write(f'{line}\n')
    # Some ops store their weight data in a layout that differs from what TF expects;
    # mark those weight tensors here so their constants can be transposed below.
    buffer_transform_dict = {}
    for i in range(subgraph.OperatorsLength()):
        op = subgraph.Operators(i)
        opcode = model.OperatorCodes(op.OpcodeIndex())
        if opcode.BuiltinCode() in (tflite.BuiltinOperator.CONV_2D, tflite.BuiltinOperator.DEPTHWISE_CONV_2D):
            buffer_transform_dict[op.Inputs(1)] = 1
        elif opcode.BuiltinCode() == tflite.BuiltinOperator.TRANSPOSE_CONV:
            buffer_transform_dict[op.Inputs(1)] = 2
    # Generate buffer definitions in TF
    transform_tensors = set()
    for i in range(subgraph.TensorsLength()):
        tensor = subgraph.Tensors(i)
        if tensor.Buffer() != 0:
            buffer = tensor.Buffer()
            dtype = TFLITE_NP_TYPE_MAPPING.get(tensor.Type(), None)
            if dtype is None:
                print('Dtype not supported:', tensor.Type())
                exit(1)
            shape = tensor.ShapeAsNumpy()
            # For marked weights, apply the corresponding layout transformation
            transform = buffer_transform_dict.get(i, None)
            data_line = (
                f'np.frombuffer(tfl_model.Buffers({buffer}).DataAsNumpy().tobytes(),'
                f' dtype="{dtype}").reshape({shape.tolist()})'
            )
            if transform in (1, 2):
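                # (1, 2, 3, 0) reorders TFLite OHWI conv kernels (or 1HWC depthwise kernels)
                # into the HWIO (or HWC1) layout that tf.nn.conv2d / tf.nn.depthwise_conv2d
                # expect; (1, 2, 0, 3) reorders OHWI transpose-conv kernels into HWOI for
                # tf.nn.conv2d_transpose (assuming the standard TFLite/TF kernel layouts).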
                if transform == 1:
                    sequence = (1, 2, 3, 0)
                else:
                    sequence = (1, 2, 0, 3)
                data_line = f'np.transpose({data_line}, {sequence})'
            elif transform is not None:
                print(f'Unknown transform: {transform}')
                exit(1)
            line = f'self.tensor_{i} = tf.constant({data_line})'
            output_file.write(f'        {line}\n')
        else:
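            # Buffer index 0 is TFLite's shared empty buffer, so these tensors carry no
            # constant data (graph inputs and intermediate activations). Collect them,
            # presumably so the per-op parsers below can tell runtime tensors apart from
            # the constants captured in __init__.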
            transform_tensors.add(i)
    output_file.write('\n')
    line = f'''
    @tf.function
    def __call__(self, {", ".join(input_names)}):'''
    output_file.write(f'{line}\n')
    # Parse the operators and generate the corresponding TF calls
    for i in range(subgraph.OperatorsLength()):
        op = subgraph.Operators(i)
        opcode = model.OperatorCodes(op.OpcodeIndex())
        code = opcode.BuiltinCode()
        op_name = OP_NAME_MAPPING.get(code)
        if op_name in OP_PARSER_DICT:
            parser = OP_PARSER_DICT[op_name]
            line = parser(op, transform_tensors)
            output_file.write(f'        {line}\n')
        else:
            print(f'Op not supported: {code} ({op_name})')
            exit(1)
    # Generate the return values according to the model outputs
    output_names = []
    for i in range(subgraph.OutputsLength()):
        outp = subgraph.Outputs(i)
        output_names.append(f'tensor_{outp}')
    input_names_for_conversion = ", ".join(f'input_{i}' for i in range(len(input_names)))
    if len(output_names) == 1:
        output_for_conversion = '{"output_0": result}'
    else:
        outputs = [f'"output_{i}": result[{i}]' for i in range(len(output_names))]
        output_for_conversion = f'{{ {", ".join(outputs)} }}'
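        # e.g. two outputs yield '{ "output_0": result[0], "output_1": result[1] }',
        # naming each element of the tuple returned by the generated __call__.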
    input_signatures_for_conversion = ', '.join(f'input_{i}={val}' for i, val in enumerate(input_signatures))
    line = f'return {", ".join(output_names)}'
    output_file.write(f'        {line}\n')
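    # CONVERT_LOGIC presumably instantiates TFModel, maps the named inputs/outputs built
    # above onto the positional tensors, and exports a SavedModel using the collected
    # tf.TensorSpec input signatures (the exact template lives elsewhere in this file).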
    conversion_logic = CONVERT_LOGIC.format(
        input_names_for_conversion, input_names_for_conversion, output_for_conversion, input_signatures_for_conversion
    )
    output_file.write(f'\n{conversion_logic}\n')
    output_file.close()
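

# Minimal usage sketch (illustrative; any real CLI entry point lives elsewhere in this file):
#
#     parse_tflite('model.tflite')   # writes generate_tf_savedmodel.py to the current directory
#     # then run the generated script to export the TF model:
#     #     python generate_tf_savedmodel.py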