tinynn/converter/operators/graph.py (472 lines of code) (raw):
import os
import queue
import typing
import warnings
import flatbuffers
import igraph as ig
from . import tflite as tfl
from .base import ExtendedOperator
from tinynn.util.util import get_logger
# Module-level logger used by the graph helpers below; requested with the 'INFO' level argument.
log = get_logger(__name__, 'INFO')
class CommonGraph(object):
    """Graph of TFLite operators and tensors, stored in a directed igraph graph."""

    # The underlying directed graph; vertices are nodes, edges are named after tensors
    graph: ig.Graph
    # Tensor name -> tensor object
    tensor_map: typing.Dict[str, tfl.Tensor]
    # Tensor name -> name of the vertex that produces the tensor
    tensor_node_map: typing.Dict[str, str]
    # Name of a list-like tensor -> names of the tensors nested in it
    iterable_map: typing.Dict[str, typing.List[str]]
    # Names of the graph input tensors
    inputs: typing.List[str]
    # Names of the graph output tensors
    outputs: typing.List[str]
    # NOTE(review): not used in this class; presumably per-input transpose flags - confirm with callers
    input_transpose: typing.List[bool]
    # NOTE(review): not used in this class; presumably per-output (or global) transpose flags - confirm
    output_transpose: typing.Union[typing.List[typing.Optional[bool]], typing.Optional[bool]]
    # Monotonic counter used to build unique names for op vertices
    node_op_counter: int
def __init__(self) -> None:
    """Create an empty graph with no tensors, nodes or I/O registered."""
    # Directed graph that stores the nodes and the tensor-named edges
    self.graph = ig.Graph(directed=True)

    # Name based lookup tables for tensors and the nodes producing them
    self.tensor_map = {}
    self.tensor_node_map = {}
    self.iterable_map = {}

    # Graph level I/O description
    self.inputs = list()
    self.outputs = list()
    self.input_transpose = list()
    self.output_transpose = None

    # Counter used for generating unique op node names
    self.node_op_counter = 0

    # Auxiliary stores, all empty at construction time
    self.q_mapping = dict()
    self.rev_q_mapping = dict()
    self.transform_store = dict()
    self.constant_mapping = dict()
def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
    """Record the name of the tensor obtained by applying a transform to a tensor

    Args:
        tensor_name (str): The name of the original tensor
        transform_name (str): The identifier of the transform
        new_tensor_name (str): The name of the transformed tensor
    """
    # One inner dict per source tensor, created on first use
    per_tensor = self.transform_store.setdefault(tensor_name, {})
    per_tensor[transform_name] = new_tensor_name
def get_transform_store(self, tensor_name: str, transform_name: str) -> typing.Optional[str]:
    """Look up the tensor name recorded for a transform of a tensor

    Args:
        tensor_name (str): The name of the original tensor
        transform_name (str): The identifier of the transform

    Returns:
        typing.Optional[str]: The value recorded through `add_transform_store` (a tensor name),
            or None when nothing was recorded for the given tensor / transform pair
    """
    # The store keeps tensor names (see `add_transform_store`), not tensor objects,
    # hence the `Optional[str]` return annotation.
    if tensor_name not in self.transform_store:
        return None
    return self.transform_store[tensor_name].get(transform_name, None)
def add_iterable_pair(
    self, input_names: typing.List[str], output_names: typing.List[str], key: typing.Optional[str] = None
):
    """Adds the tensor mapping for a ListConstruct tensor

    The list-like tensor is the first name on the "key" side, and the names on the other side
    are appended to its nested names.

    Args:
        input_names (typing.List[str]): The names of the input tensors
        output_names (typing.List[str]): The names of the output tensors
        key (typing.Literal['input', 'output'], optional): Which side is used as key. When it is not
            'input' or 'output', the side is inferred from the lengths (one name vs. several names).
            Defaults to None.

    Raises:
        AssertionError: If `key` is not given and the side cannot be inferred from the lengths
    """
    # An explicit key always wins; the length heuristic is only a fallback. Previously
    # key='output' was ignored whenever the lengths happened to look like the 'input' case.
    if key == 'input':
        input_is_key = True
    elif key == 'output':
        input_is_key = False
    elif len(input_names) == 1 and len(output_names) > 1:
        input_is_key = True
    elif len(input_names) > 1 and len(output_names) == 1:
        input_is_key = False
    else:
        # Raised explicitly (instead of `assert False`) so that the check survives `python -O`
        raise AssertionError("You should specify key == 'input' or 'output'")

    if input_is_key:
        list_name, nested_names = input_names[0], output_names
    else:
        list_name, nested_names = output_names[0], input_names

    self.iterable_map.setdefault(list_name, []).extend(nested_names)
def has_nested_names(self, key: str) -> bool:
    """Tell whether nested tensor names were registered for a tensor

    Args:
        key (str): The name of the tensor

    Returns:
        bool: True if the name is a key of the iterable map (i.e. a ListConstruct tensor)
    """
    registered_lists = self.iterable_map
    return key in registered_lists
def get_list_expanded_names(self, key: str) -> typing.List[str]:
    """Return the nested tensor names registered for a ListConstruct tensor

    Args:
        key (str): The name of the ListConstruct tensor

    Returns:
        typing.List[str]: The names of the nested tensors (the stored list itself, not a copy)

    Raises:
        KeyError: If no nested names were registered for `key`
    """
    nested_names = self.iterable_map[key]
    return nested_names
def check_tensor(self, name: str, node_type: ExtendedOperator, tensor: tfl.Tensor) -> ig.Vertex:
    """Fetch the node that already produces a tensor and validate the bookkeeping

    Args:
        name (str): The name of the tensor
        node_type (ExtendedOperator): The type of the node (currently not verified)
        tensor (tfl.Tensor): The tensor that is expected to be registered under `name`

    Returns:
        ig.Vertex: The node that produces the tensor
    """
    producer_name = self.tensor_node_map[name]
    producer = self.graph.vs.find(name=producer_name)

    assert name in self.tensor_map, f"tensor {name} is in nodes map, but not in tensors map"
    # The node type is deliberately not compared against `node_type` here.
    # The registered tensor has to be the very same object, not just an equal one.
    assert self.tensor_map[name] is tensor, f"tensor {name} already exists"

    return producer
def add_nodes(
    self, tensors: typing.List[tfl.Tensor], node_type=ExtendedOperator.CONSTANT_NODE
) -> typing.List[ig.Vertex]:
    """Create (or look up) one node per tensor, usually for special node types

    Args:
        tensors (typing.List[tfl.Tensor]): The output tensors of the nodes
        node_type ([type], optional): The type of the node. Defaults to ExtendedOperator.CONSTANT_NODE.

    Returns:
        typing.List[ig.Vertex]: The nodes for the tensors, in the same order as `tensors`
    """
    uses_output_suffix = node_type in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE)

    nodes = []
    for t in tensors:
        if uses_output_suffix:
            # Output-like nodes get a dedicated name that must not clash with a known tensor:
            # `<name>_output`, then `<name>_output_1`, `<name>_output_2`, ...
            tensor_name = t.name + '_output'
            i = 1
            while tensor_name in self.tensor_map:
                tensor_name = f'{t.name}_output_{i}'
                i += 1
        else:
            tensor_name = t.name

        if tensor_name in self.tensor_node_map:
            # A node already produces this tensor, reuse it after validation
            nodes.append(self.check_tensor(tensor_name, node_type, t))
            continue

        node = self.graph.add_vertex(
            node_type=node_type,
            outputs=[tensor_name],
            label=ExtendedOperator(node_type).type_name(),
            name=tensor_name,
        )
        self.tensor_map[tensor_name] = t
        self.tensor_node_map[tensor_name] = node['name']
        nodes.append(node)

    return nodes
def add_node(self, tensors: typing.List[tfl.Tensor], tfl_op: tfl.BaseOperator, output_exists: bool = False):
    """Create the node for an op and register its output tensors

    Args:
        tensors (typing.List[tfl.Tensor]): The output tensors of the node
        tfl_op (tfl.BaseOperator): The op to be added
        output_exists (bool, optional): Whether the output may already exists. Defaults to False.

    Returns:
        The newly created node of the op
    """
    # Every op node gets a unique generated name
    node_unique_name = f'__tinynn_op_{self.node_op_counter}__'
    self.node_op_counter += 1

    # `custom_type` is only attached for custom ops; the key order mirrors the keyword order
    vertex_attrs = {'node_type': tfl_op.op.code}
    if tfl_op.op.custom_code is not None:
        vertex_attrs['custom_type'] = tfl_op.op.custom_code
    vertex_attrs['outputs'] = [out.name for out in tfl_op.outputs]
    vertex_attrs['op'] = tfl_op
    vertex_attrs['label'] = tfl_op.type_name()
    vertex_attrs['name'] = node_unique_name
    node = self.graph.add_vertex(**vertex_attrs)

    log.debug(f'NEW VERTEX: {node["op"].type_name()}[{node["name"]}] {node["op"].inputs} -> {node["op"].outputs}')

    for t in tensors:
        if output_exists:
            if t.name in self.tensor_map:
                # A known tensor has to stay the same during graph reconstruction
                assert (
                    self.tensor_map[t.name] == t
                ), f"output tensor ({t.name}) has changed during graph reconstruction"
            else:
                log.debug(f'tensor node map add {t.name} during transformation')
                self.tensor_map[t.name] = t
        else:
            assert (
                t.name not in self.tensor_node_map
            ), f"output tensor ({t.name}) should not be in the nodes map at this time"
            self.tensor_map[t.name] = t

        # In all cases, the new node becomes the producer of the tensor
        self.tensor_node_map[t.name] = node['name']

    return node
def add_outputs(self, names: typing.List[str], node_type=ExtendedOperator.OUTPUT_NODE):
    """Create output nodes for the given tensors and connect them to their producers

    Args:
        names (typing.List[str]): The names of the output nodes to be created
        node_type ([type], optional): The type of the created nodes. Defaults to ExtendedOperator.OUTPUT_NODE.
    """
    if len(names) == 0:
        return

    output_tensors = [self.tensor_map[x] for x in names]
    output_nodes = self.add_nodes(output_tensors, node_type)

    for name, output_node in zip(names, output_nodes):
        # The edge starts at the node that currently produces the tensor
        current_node = self.graph.vs.find(name=self.tensor_node_map[name])
        edge = self.graph.add_edge(current_node, output_node, name=output_node["outputs"][0], label=name)
        log.debug(
            f'NEW EDGE: {current_node["label"]} -> {output_node["label"]} {self.tensor_map[edge["name"]]}'
        )
def add_operator(self, tfl_op: tfl.BaseOperator, transform: bool = False):
    """Insert an operator into the graph and wire it to its inputs

    Args:
        tfl_op (tfl.BaseOperator): The operator be added
        transform (bool, optional): Whether it is created by a transformable node. Defaults to False.
    """
    # Nodes for the inputs first, then the op node itself
    input_nodes = self.add_nodes(tfl_op.inputs)
    current_node = self.add_node(tfl_op.outputs, tfl_op, transform)

    for input_node, input_tensor in zip(input_nodes, tfl_op.inputs):
        edge = self.graph.add_edge(input_node, current_node, name=input_tensor.name, label=input_tensor.name)
        log.debug(f'NEW EDGE: {input_node["label"]} -> {current_node["label"]} {self.tensor_map[edge["name"]]}')

    # Outputs of the op that are also graph outputs get their own output nodes
    produced_names = [t.name for t in tfl_op.outputs]
    output_names = set(self.outputs).intersection(set(produced_names))
    self.add_outputs(output_names)
def try_restore_edges(self, mapping: typing.List[typing.Tuple[str, str]]):
    """Re-create edges towards nodes that still exist in the graph

    Args:
        mapping (typing.List[typing.Tuple[str, str]]): A list of pairs (edge name, target node name)
    """
    for edge_name, node_name in mapping:
        cand = self.graph.vs.select(name=node_name)
        if not cand:
            # The target node is gone, so there is nothing to restore
            continue

        next_node = cand[0]
        # The source of the edge is the node that produces the tensor named `edge_name`
        prev_node = self.graph.vs.find(name=self.tensor_node_map[edge_name])
        edge = self.graph.add_edge(prev_node, next_node, name=edge_name, label=edge_name)
        log.debug(f'NEW EDGE: {prev_node["label"]} -> {next_node["label"]} {self.tensor_map[edge["name"]]}')
def remove_operator_input(
    self, node: ig.Vertex, input_idx: int, return_ids: bool = False, skip: int = 0
) -> typing.Optional[typing.List[int]]:
    """Drop the edge that feeds one input tensor into an op node

    At most one edge is selected: the first incoming edge that is named after the tensor and
    whose source node lists the tensor in its outputs, after skipping `skip` such matches.

    Args:
        node (ig.Vertex): An op node
        input_idx (int): the index of the input tensor
        return_ids (bool): Return the ids instead of removing the edges. Defaults to False.
        skip (int): Number of items to skip

    Returns:
        typing.Optional[typing.List[int]]: The edges to be removed if return_ids is True, otherwise None
    """
    old_tensor = node['op'].inputs[input_idx]
    assert old_tensor.name in self.tensor_map

    remove_edges = []
    for edge in node.in_edges():
        producer = self.graph.vs[edge.source]
        for produced_name in producer['outputs']:
            if produced_name != old_tensor.name or edge['name'] != old_tensor.name:
                continue
            if skip > 0:
                # This match is skipped on request
                skip -= 1
                continue
            remove_edges.append(edge.index)
            break
        if remove_edges:
            # Only a single edge is selected per call
            break

    if not return_ids:
        self.graph.delete_edges(remove_edges)
        return None
    return remove_edges
def replace_operator_input(
    self, node: ig.Vertex, input_idx: int, new_tensor: tfl.Tensor, return_ids: bool = False, skip: int = 0
) -> typing.Optional[typing.List[int]]:
    """Swap one input tensor of an op node for another tensor

    Args:
        node (ig.Vertex): An op node
        input_idx (int): the index of the input tensor
        new_tensor (tfl.Tensor): The tensor to be be used
        return_ids (bool): Return the ids instead of removing the edges. Defaults to False.
        skip (int): Number of items to skip

    Returns:
        typing.Optional[typing.List[int]]: The edges to be removed if return_ids is True, otherwise None
    """
    # Collect the stale edge ids before the input list of the op is modified
    stale_edges = self.remove_operator_input(node, input_idx, return_ids=True, skip=skip)

    node['op'].inputs[input_idx] = new_tensor
    (new_node,) = self.add_nodes([new_tensor])
    edge = self.graph.add_edge(new_node, node, name=new_tensor.name, label=new_tensor.name)
    log.debug(f'NEW EDGE: {new_node["label"]} -> {node["label"]} {self.tensor_map[edge["name"]]}')

    if not return_ids:
        self.graph.delete_edges(stale_edges)
        return None
    return stale_edges
def append_operator_input(self, node: ig.Vertex, new_tensor: tfl.Tensor, as_intermediate: bool = False):
    """Attach an additional tensor to an op node and connect its producer

    Args:
        node (ig.Vertex): An op node
        new_tensor (tfl.Tensor): The tensor to be added
        as_intermediate (bool, optional): Append to the intermediates of the op instead of its inputs.
            Defaults to False.
    """
    op = node['op']
    destination = op.intermediates if as_intermediate else op.inputs
    destination.append(new_tensor)

    (new_node,) = self.add_nodes([new_tensor])
    edge = self.graph.add_edge(new_node, node, name=new_tensor.name, label=new_tensor.name)
    log.debug(f'NEW EDGE: {new_node["label"]} -> {node["label"]} {self.tensor_map[edge["name"]]}')
def remove_operator(self, tfl_op: 'tfl.BaseOperator'):
    """Remove the node of an operator from the graph

    The node is located through an edge named after the first output tensor of the op, so that
    tensor needs at least one outgoing edge in the graph.

    Args:
        tfl_op (tfl.BaseOperator): The operator to be removed
    """
    tensor_edge = self.graph.es.find(name=tfl_op.outputs[0].name)
    # `Edge.source` already is the integer id of the source vertex (it is not a Vertex object),
    # so it is passed to `delete_vertices` directly.
    op_node_index = tensor_edge.source
    self.graph.delete_vertices([op_node_index])
def remove_operators(self, tfl_ops: typing.List['tfl.BaseOperator']):
    """Remove the nodes of several operators from the graph in one batch

    Each node is located through an edge named after the first output tensor of its op, so that
    tensor needs at least one outgoing edge in the graph.

    Args:
        tfl_ops (typing.List[tfl.BaseOperator]): The operators to be removed
    """
    indices = []
    for tfl_op in tfl_ops:
        tensor_edge = self.graph.es.find(name=tfl_op.outputs[0].name)
        # `Edge.source` already is the integer id of the source vertex (it is not a Vertex object)
        indices.append(tensor_edge.source)
    # All ids are resolved before anything is deleted, since deleting vertices renumbers the rest
    self.graph.delete_vertices(indices)
def connect_next_tensors(
    self,
    find_node: ig.Vertex,
    connect_node: ig.Vertex,
    tensor_name: str,
    skips_nodes: typing.Optional[typing.List[str]] = None,
):
    """Connect `connect_node` to every successor of `find_node` through the tensor `tensor_name`

    Args:
        find_node ([ig.Vertex]): The node to search for next nodes
        connect_node ([ig.Vertex]): The node to connect the next nodes with
        tensor_name ([str]): The name of the edge (tensor)
        skips_nodes ([typing.Optional[typing.List[str]]]): The name of the next nodes to skip
    """
    output_like_types = (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE)

    for next_tensor in find_node.out_edges():
        next_op = self.graph.vs[next_tensor.target]
        if skips_nodes is not None and next_op['name'] in skips_nodes:
            continue

        if next_op['node_type'] in output_like_types:
            # Edges towards output-like nodes carry the suffixed name, which is kept as-is
            assert next_tensor['name'].startswith(
                tensor_name + '_output'
            ), f'output tensor and node name mismatches: {tensor_name} vs {next_tensor["name"]}'
            new_edge_name = next_tensor['name']
        else:
            assert (
                tensor_name == next_tensor['name']
            ), f'next tensor name mismatches: {tensor_name} vs {next_tensor["name"]}'
            new_edge_name = tensor_name

        self.graph.add_edge(connect_node, next_op, name=new_edge_name, label=tensor_name)
        log.debug(f'NEW EDGE: {connect_node["label"]} -> {next_op["label"]} {self.tensor_map[next_tensor["name"]]}')
def replace_next_tensors(
    self,
    find_node: ig.Vertex,
    connect_node: ig.Vertex,
    tensor_name: str,
    skips_nodes: typing.Optional[typing.List[str]] = None,
):
    """A variant of connect_next_tensors that also replace the tensors in the next nodes

    Every input of a successor op that refers to the first output of `find_node` is replaced with
    the tensor registered as `tensor_name`, and a new edge from `connect_node` is added.

    Args:
        find_node ([ig.Vertex]): The node to search for next nodes
        connect_node ([ig.Vertex]): The node to connect the next nodes with
        tensor_name ([str]): The name of the edge (tensor)
        skips_nodes ([typing.Optional[typing.List[str]]]): The name of the next nodes to skip

    Raises:
        AssertionError: If a successor is an output node, or if an outgoing edge is not named
            after the first output of `find_node`
    """
    orig_name = find_node['outputs'][0]
    for next_tensor in find_node.out_edges():
        next_op = self.graph.vs[next_tensor.target]
        if skips_nodes is not None and next_op['name'] in skips_nodes:
            continue

        if next_op['node_type'] == ExtendedOperator.OUTPUT_NODE:
            # Raised explicitly (instead of `assert False`) so that the check survives `python -O`
            raise AssertionError('replace_next_tensors where last_node.next is an output node is not supported')

        # The message reports the names that are actually compared (it used to print `tensor_name`)
        assert (
            orig_name == next_tensor['name']
        ), f'next tensor name mismatches: {orig_name} vs {next_tensor["name"]}'

        op = next_op['op']
        for idx, t in enumerate(op.inputs):
            if t.name == orig_name:
                op.inputs[idx] = self.tensor_map[tensor_name]
        self.graph.add_edge(connect_node, next_op, name=tensor_name, label=tensor_name)

        log.debug(f'NEW EDGE: {connect_node["label"]} -> {next_op["label"]} {self.tensor_map[next_tensor["name"]]}')
        log.debug(f'{next_op["label"]} {next_op["op"].inputs} {next_op["op"].outputs}')
def visualize(self, hide_constants=True):
    """Draw the TinyNeuralNetwork graph with matplotlib

    Args:
        hide_constants (bool, optional): Hide constants in the plot. Defaults to True.
    """
    self.check()

    # matplotlib is only needed for plotting, so it is imported lazily
    import matplotlib.pyplot as plt

    _, axs = plt.subplots()

    subgraph = self.graph
    if hide_constants:
        # Keep only the nodes that are not constant nodes
        nodes = self.graph.vs.select(node_type_ne=ExtendedOperator.CONSTANT_NODE)
        subgraph = self.graph.induced_subgraph(nodes)

    visual_style = {
        "vertex_label_size": 5,
        "vertex_label": subgraph.vs["outputs"],
        "layout": "drl",
        "bbox": (800, 800),
        "margin": 20,
    }
    ig.plot(subgraph, target=axs, **visual_style)

    axs.axis("off")
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.show()
def check(self):
    """Assert that the graph is a directed acyclic graph"""
    graph = self.graph
    assert graph.is_dag(), "The graph is not a DAG"
    assert graph.is_directed(), "The graph is not directed"
    # Stricter checks (the graph being simple, i.e. no multiple edges between a pair of nodes,
    # and being weakly connected) would only hold for simple NNs. Since it is hard to tell
    # whether a NN is simple or not, they are intentionally not enforced here.
def topological_sort(self) -> typing.List[int]:
    """Sort the graph topologically

    Constant nodes are emitted first. The remaining nodes are emitted by a stack-based
    depth-first traversal that starts from the input nodes and from the op nodes that have no
    unvisited predecessor, and that only emits a node once the sources of all of its incoming
    edges have been emitted.

    Returns:
        typing.List[int]: The sorted indices of the nodes
    """
    # Emulating DFS with LifoQueue(stack)
    q = queue.LifoQueue()
    visited = set()
    indices = []

    # We push all inputs nodes to the target queue.
    inputs = [v for v in self.graph.vs if v['node_type'] == ExtendedOperator.INPUT_NODE]
    # Nodes with a non-negative type and no incoming edge are additional starting points.
    # NOTE(review): a non-negative type presumably means "real op" (special nodes being negative) - confirm.
    other_input_nodes = [v for v in self.graph.vs if v['node_type'] >= 0 and v.indegree() == 0]

    # Constants are all known, so just marking them here.
    constants = [v for v in self.graph.vs if v['node_type'] == ExtendedOperator.CONSTANT_NODE]
    for c in constants:
        indices.append(c.index)
        visited.add(c.index)
        for e in c.out_edges():
            v = e.target_vertex
            if v not in other_input_nodes:
                # A consumer only becomes a starting point when every source of its incoming
                # edges has been visited so far.
                # NOTE(review): the inner loop rebinds `e`; the outer iteration is unaffected.
                skip = False
                for e in v.in_edges():
                    if e.source not in visited:
                        skip = True
                        break
                if skip:
                    continue
                if v['node_type'] >= 0:
                    other_input_nodes.append(v)
                else:
                    # Negative node types other than the output node are reported
                    if v['node_type'] != ExtendedOperator.OUTPUT_NODE:
                        type_name = ExtendedOperator(v['node_type']).type_name()
                        log.warning(
                            f'The child node of a constant node is of type {type_name}, which is unexpected'
                        )

    # Report the extra starting points, except for the op types listed below
    for v in other_input_nodes:
        if v['node_type'] not in (
            ExtendedOperator.ASSIGN_VARIABLE,
            ExtendedOperator.READ_VARIABLE,
            ExtendedOperator.RANDOM_STANDARD_NORMAL,
            ExtendedOperator.MULTINOMIAL,
            ExtendedOperator.RANDOM_UNIFORM,
        ):
            output_name = v['outputs'][0]
            type_name = v['op'].type_name()
            log.warning(f'{type_name}({output_name}) is an orphaned node, which is unexpected')

    # Pushed in reverse, so that the first starting point is popped first from the stack
    for i in reversed(inputs + other_input_nodes):
        q.put(i)

    while not q.empty():
        v = q.get()
        # Skip if already visited
        if v.index in visited:
            continue

        # Ensure all input nodes are visited
        # (a dropped node is pushed again when another of its predecessors is emitted)
        skip = False
        for e in v.in_edges():
            if e.source not in visited:
                skip = True
                break

        if skip:
            continue

        # Mark visited if the previous constraints are met
        visited.add(v.index)
        indices.append(v.index)

        # Push the out nodes to the target queue
        for e in reversed(v.out_edges()):
            q.put(e.target_vertex)

    return indices
def collect_operators(
    self, ops: typing.Optional[typing.List[tfl.BaseOperator]] = None
) -> typing.List[tfl.BaseOperator]:
    """Number the operators and record the tensor indices they refer to

    Args:
        ops (typing.Optional[typing.List[tfl.BaseOperator]], optional): TFLite operators. When omitted,
            the ops of the graph are used, in the order given by `topological_sort`. Defaults to None.

    Returns:
        typing.List[tfl.BaseOperator]: operators with the numbered index
    """
    if ops is None:
        # The custom sort is used for figuring out a better order than the one from
        # `self.graph.topological_sorting()`
        ids = self.topological_sort()
        graph_vertices = self.graph.vs
        sorted_vertices = (graph_vertices[vertex_id] for vertex_id in ids)
        # Only vertices with a non-negative node type contribute an op; evaluation stays lazy
        ops = (vertex['op'] for vertex in sorted_vertices if vertex['node_type'] >= 0)

    log.debug('Collecting operators...')

    result = []
    for idx, op in enumerate(ops):
        log.debug(f'[{idx}] {op.type_name()} {op.inputs} -> {op.outputs}')
        op.op.index = idx
        # Snapshot the current indices of the tensors attached to the op
        op.tfl_inputs_idx = [tensor.index for tensor in op.inputs]
        op.tfl_outputs_idx = [tensor.index for tensor in op.outputs]
        op.tfl_intermediates_idx = [tensor.index for tensor in op.intermediates]
        result.append(op)

    return result
def collect_tensor_buffers(
    self,
    labels: typing.Set[str] = None,
    inputs: typing.List[str] = None,
    outputs: typing.List[str] = None,
    tensor_map: typing.Dict[str, tfl.Tensor] = None,
) -> typing.Tuple[typing.List[tfl.Tensor], typing.List[tfl.Buffer], typing.List[int], typing.List[int]]:
    """Collect tensors, buffers and I/O indices

    The tensors and buffers are numbered in place (their `index` attributes are overwritten).

    Args:
        labels (typing.Set[str], optional): TFLite tensor names. Defaults to the labels of all edges.
        inputs (typing.List[str], optional): Input tensor names. Defaults to `self.inputs`.
        outputs (typing.List[str], optional): Output tensor names. Defaults to `self.outputs`.
        tensor_map (typing.Dict[str, tfl.Tensor], optional): All tensors. Defaults to `self.tensor_map`.

    Returns:
        typing.Tuple[typing.List[tfl.Tensor], typing.List[tfl.Buffer], typing.List[int], typing.List[int]]:
            tensors, buffers with the numbered index and I/O indices

    Raises:
        AssertionError: If an output name did not get an index
    """
    if labels is None:
        labels = set(self.graph.es['label'])

    if inputs is None:
        inputs = self.inputs

    if outputs is None:
        outputs = self.outputs

    if tensor_map is None:
        tensor_map = self.tensor_map

    tensor_idx = 0
    # Buffer 0 is an empty buffer that is always present, so real buffers start at 1
    buffer_idx = 1
    tensors = []
    buffers = [tfl.Buffer(bytes(0))]

    # -1 marks an I/O name that has not been assigned a tensor index (yet)
    input_idx = [-1] * len(inputs)
    output_idx = [-1] * len(outputs)

    for label in labels:
        tensor: tfl.Tensor = tensor_map[label]
        # Tensors whose index is -1 are left out of the numbering
        # NOTE(review): presumably these are optional (absent) tensors - confirm
        if tensor.index != -1:
            if tensor.is_variable:
                # Variable tensors refer to the shared empty buffer
                tensor.buffer.index = 0

            tensor.index = tensor_idx
            tensor_idx += 1
            tensors.append(tensor)

            # Only non-variable tensors that own a buffer contribute a new buffer
            if tensor.buffer is not None and tensor.is_variable is False:
                tensor.buffer.index = buffer_idx
                buffer_idx += 1
                buffers.append(tensor.buffer)

            # A name may appear several times in the I/O lists, so every occurrence is filled in
            if label in inputs:
                item_indices = [i for i, x in enumerate(inputs) if x == label]
                for item_idx in item_indices:
                    input_idx[item_idx] = tensor.index

            if label in outputs:
                item_indices = [i for i, x in enumerate(outputs) if x == label]
                for item_idx in item_indices:
                    output_idx[item_idx] = tensor.index

    missing_inputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(inputs, input_idx))]
    missing_outputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(outputs, output_idx))]

    # Missing outputs are fatal, while missing inputs are appended as extra tensors
    assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}'
    if len(missing_inputs) != 0:
        warnings.warn(f'Some input nodes are missing: {missing_inputs}, will try to add them into graph')
        for name in missing_inputs:
            # NOTE(review): this reads `self.tensor_map` rather than the `tensor_map` argument, and
            # `inputs.index` only resolves the first occurrence of a repeated name - confirm intent
            tensor = self.tensor_map[name]
            tensor.index = tensor_idx
            tensor_idx += 1
            tensors.append(tensor)
            item_idx = inputs.index(name)
            input_idx[item_idx] = tensor.index

    return tensors, buffers, input_idx, output_idx
def convert(self, tflite_path: str):
    """Serialize the TinyNeuralNetwork Graph into a tflite model file

    In addition to the main model, every op that carries an 'orig_float' entry in its
    `extra_hints` gets two single-op models written next to the main file, one for the
    original float op and one for the op itself, with `_float_<index>` / `_dq_<index>`
    appended to the file name (before the extension).

    Args:
        tflite_path ([str]): Path of the generated tflite model
    """

    def write_model(path, payload):
        # Check output directory, then write to file
        target_dir = os.path.abspath(os.path.dirname(path))
        os.makedirs(target_dir, exist_ok=True)
        with open(path, 'wb') as f:
            f.write(payload)

    # Collect multiple data to build a tflite model, then construct the flatbuffer model
    tensors, buffers, input_idx, output_idx = self.collect_tensor_buffers()
    full_ops = self.collect_operators()
    write_model(tflite_path, self.build_model(full_ops, tensors, buffers, input_idx, output_idx))

    for v in self.graph.vs:
        dq_op = v['op']
        if dq_op is None:
            continue

        orig_op = dq_op.extra_hints.get('orig_float', None)
        if orig_op is None:
            continue

        # Position of the op in the full model, used in the file names
        index = full_ops.index(dq_op)

        for k, op in (('float', orig_op), ('dq', dq_op)):
            # Tensors without a buffer (optional tensors excluded) are the I/O of the single-op model
            inputs = [x.name for x in op.inputs if x.buffer is None and not isinstance(x, tfl.OptionalTensor)]
            outputs = [x.name for x in op.outputs if x.buffer is None and not isinstance(x, tfl.OptionalTensor)]
            tensor_map = {t.name: t for t in op.inputs + op.outputs}

            tensors, buffers, input_idx, output_idx = self.collect_tensor_buffers(
                tensor_map.keys(), inputs, outputs, tensor_map
            )
            single_op = self.collect_operators([op])
            tflite_model = self.build_model(single_op, tensors, buffers, input_idx, output_idx)

            fn, ext = os.path.splitext(tflite_path)
            fn += f'_{k}_{index}'
            write_model(f'{fn}{ext}', tflite_model)
def build_model(
    self,
    ops: typing.List[tfl.BaseOperator],
    tensors: typing.List[tfl.Tensor],
    buffers: typing.List[tfl.Buffer],
    input_idx: typing.List[int],
    output_idx: typing.List[int],
) -> bytearray:
    """Assemble the flatbuffer model out of the collected pieces

    Args:
        ops (typing.List[tfl.BaseOperator]): TFLite operators
        tensors (typing.List[tfl.Tensor]): TFLite tensors
        buffers (typing.List[tfl.Buffer]): TFLite buffers
        input_idx (typing.List[int]): The indices of the input tensors
        output_idx (typing.List[int]): The indices of the output tensors

    Returns:
        bytearray: The built flatbuffer model
    """
    builder = flatbuffers.Builder(0)

    # Everything is written through the same builder, in this order:
    # tensors, operators, operator codes and finally buffers
    tensor_offsets = []
    for tensor in tensors:
        tensor_offsets.append(tensor.build(builder))

    op_offsets = []
    for op in ops:
        op_offsets.append(op.build(builder))

    opcode_offsets = []
    for op in ops:
        opcode_offsets.append(op.op.build(builder))

    buffer_offsets = []
    for buffer in buffers:
        buffer_offsets.append(buffer.build(builder))

    # The single subgraph of the model
    subgraph = tfl.SubGraph()
    subgraph.tensors.extend(tensor_offsets)
    subgraph.inputs.extend(input_idx)
    subgraph.outputs.extend(output_idx)
    subgraph.operators.extend(op_offsets)

    # The model that wraps the subgraph
    model = tfl.Model()
    model.buffers.extend(buffer_offsets)
    model.subgraphs.append(subgraph.build(builder))
    model.opcodes.extend(opcode_offsets)
    model_offset = model.build(builder)

    # Finish with the "TFL3" file identifier and hand out the bytes
    builder.Finish(model_offset, b"TFL3")
    return builder.Output()