# tinynn/converter/utils/tflite.py
from ..schemas.tflite import schema_generated as tflite


def parse_model(path):
    """Load and parse a TFLite flatbuffer model from the given file path."""
    with open(path, 'rb') as f:
        buf = f.read()
    model = tflite.Model.GetRootAsModel(buf, 0)
    return model


def parse_lstm_states(model):
    """Collect the indices of the variable (state) tensors consumed by the LSTM ops in a model.

    Accepts a file path, a serialized flatbuffer, or an already-parsed ``tflite.Model``.
    """
    if isinstance(model, str):
        model = parse_model(model)
    elif isinstance(model, bytes):
        model = tflite.Model.GetRootAsModel(model, 0)
    elif not isinstance(model, tflite.Model):
        assert False, f"expected str, bytes or tflite.Model but got {type(model).__name__}"

    assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
    subgraph = model.Subgraphs(0)

    state_idx = []
    for i in range(subgraph.OperatorsLength()):
        op = subgraph.Operators(i)
        opcode = model.OperatorCodes(op.OpcodeIndex())
        if opcode.DeprecatedBuiltinCode() in (
            tflite.BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM,
            tflite.BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM,
        ):
            for j in range(op.InputsLength()):
                tensor_idx = op.Inputs(j)
                # Skip optional inputs, which are encoded as -1
                if tensor_idx < 0:
                    continue
                op_input = subgraph.Tensors(tensor_idx)
                # Variable tensors hold the recurrent hidden/cell state
                if op_input.IsVariable():
                    state_idx.append(tensor_idx)
    return state_idx
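

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): assumes a hypothetical
    # 'model.tflite' on disk that contains a UNIDIRECTIONAL_SEQUENCE_LSTM or
    # BIDIRECTIONAL_SEQUENCE_LSTM operator. Run as a module, e.g.
    # `python -m tinynn.converter.utils.tflite path/to/model.tflite`,
    # so that the relative import above resolves.
    import sys

    model_path = sys.argv[1] if len(sys.argv) > 1 else 'model.tflite'
    state_indices = parse_lstm_states(model_path)
    print(f'LSTM state tensor indices: {state_indices}')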