# tinynn/converter/operators/torch/prim.py
import torch
import numpy as np
from . import PrimOperatorConverter
from .. import tflite as tfl
from tinynn.util.util import get_logger
log = get_logger(__name__, 'INFO')
class PrimConstantConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Materialize a `prim::Constant` node's value as the output tensor."""
        if attrs is None:
            # No attributes at all: the constant carries no value.
            self.output_tensors = None
            return
        value, value_kind = attrs.get('value', (None, None))
        # Tensor-like values report their dtype; plain Python values their type name.
        value_type = value.dtype if hasattr(value, "dtype") else type(value).__name__
        log.debug(f'{node.kind()} {self.input_names} -> {self.output_names} {value_kind} {value_type}')
        self.output_tensors.append(value)
class PrimTupleConstructConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Pack every input into a single tuple output."""
        packed = tuple(self.input_tensors)
        self.output_tensors.append(packed)
class PrimDictConstructConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Build a dict from a flat [key, value, key, value, ...] input list."""
        assert len(self.input_tensors) % 2 == 0
        keys = self.input_tensors[::2]
        values = self.input_tensors[1::2]
        self.output_tensors.append(dict(zip(keys, values)))
class PrimListConstructConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Pack every input into a single list output and track the name mapping."""
        packed = [*self.input_tensors]
        self.output_tensors.append(packed)
        # Record the element-name correspondence so later ops can expand the list
        graph_converter.add_iterable_pair(self.input_names, self.output_names, 'output')
class PrimListUnpackConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Expand a list/tuple input into one output tensor per element."""
        packed = self.input_tensors[0]
        assert type(packed) in (list, tuple)
        assert len(packed) == len(self.output_names)
        self.output_tensors.extend(packed)
        try:
            expanded_names = graph_converter.get_list_expanded_names(self.input_names[0])
            inputs = self.to_tfl_tensors(expanded_names, packed, graph_converter=graph_converter)
            outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
            # Tie each element to its output with a no-op reshape; the optimizer
            # removes these passes later
            for inp, outp in zip(inputs, outputs):
                shape_tensor = self.create_attr_tensor(np.array(outp.shape, dtype='int32'))
                graph_converter.add_operator(tfl.ReshapeOperator([inp, shape_tensor], [outp], outp.shape))
        except KeyError:
            # The input is not tracked, nothing needs to be done to the graph converter
            pass
class PrimGetAttrConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Read a Python attribute (named by the `name` attr) off the first input."""
        name, name_type = attrs.get('name', (None, None))
        # The attribute must be present and be a string ('s'-typed)
        assert (
            name is not None and name_type == 's'
        ), f"prim::GetAttr({self.output_names[0]}) needs attribute `name` with type str"
        self.output_tensors.append(getattr(self.input_tensors[0], name))
class PrimNumToTensorConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Wrap a Python int/float scalar in a tensor, downcasting 64-bit dtypes.

        TFLite has no int64/float64 support, so those dtypes are narrowed to
        their 32-bit counterparts with a warning.
        """
        scalar = self.input_tensors[0]
        assert type(scalar) in (int, float)
        assert len(self.input_tensors) == len(self.output_names)
        tensor = torch.tensor(scalar)
        # dtype -> (source name, target name, target torch dtype)
        downcast_rules = {
            torch.int64: ('int64', 'int32', torch.int32),
            torch.float64: ('float64', 'float32', torch.float32),
        }
        rule = downcast_rules.get(tensor.dtype)
        if rule is not None:
            src_name, dst_name, target_dtype = rule
            log.warning(
                f'{self.output_names[0]} is of type {src_name}, which is unsupported in TFLite, trying to downcast to'
                f' {dst_name}'
            )
            tensor = tensor.to(dtype=target_dtype)
        self.output_tensors.append(tensor)
class PrimIfConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Resolve a statically-known `prim::If` by scheduling the taken branch."""
        assert len(self.input_tensors) == 1
        cond = self.input_tensors[0]
        assert isinstance(cond, (bool, int))
        assert len(self.output_names) == 0
        # A condition that lives in the graph would require runtime branching
        assert self.input_names[0] not in graph_converter.tensor_map, 'Dynamic control flow is not supported'
        branches = list(node.blocks())
        assert len(branches) == 2
        taken = branches[0] if cond in (True, 1) else branches[1]
        self.output_nodes.extend(taken.nodes())
class PrimGetItemConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Index the first input with the second input as the key."""
        container = self.input_tensors[0]
        key = self.input_tensors[1]
        self.output_tensors.append(container[key])
class PrimLenConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Emit the Python length of the first input."""
        container = self.input_tensors[0]
        self.output_tensors.append(len(container))
class PrimConstantChunkConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Lower `prim::ConstantChunk` to TFLite SPLIT / SPLIT_V.

        `chunks` and `dim` come from the node attributes and must both be ints
        ('i'-typed). torch.chunk may return fewer chunks than requested and may
        produce unequal chunk sizes; the even case maps to SPLIT, the uneven
        one to SPLIT_V.
        """
        chunks, chunks_type = attrs.get('chunks', (None, None))
        dim, dim_type = attrs.get('dim', (None, None))
        if chunks is None or chunks_type != 'i':
            assert False, f"prim::ConstantChunk({self.output_names[0]}) needs attribute `chunks` with type int"
        if dim is None or dim_type != 'i':
            assert False, f"prim::ConstantChunk({self.output_names[0]}) needs attribute `dim` with type int"
        # Compute the chunked results eagerly; these become the output tensors
        v = torch.chunk(self.input_tensors[0], chunks, dim)
        self.output_tensors.extend(v)
        # Graph operations only take place when the input tensor is tracked
        if self.input_names[0] in graph_converter.tensor_map:
            outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
            # Normalize a negative dim to a non-negative axis index
            if dim < 0:
                dim += len(self.input_tensors[0].shape)
            dim_size = self.input_tensors[0].size(dim)
            # torch.chunk never yields more chunks than the dim has elements
            if chunks > dim_size:
                chunks = dim_size
            input_tensor = self.find_or_create_input(0, graph_converter)
            dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
            if dim_size % chunks != 0:
                # Uneven split: pass per-chunk sizes explicitly via SPLIT_V
                size_splits = np.array([t.size(dim) for t in self.output_tensors], dtype='int32')
                chunks = len(size_splits)
                split_tensor = self.create_attr_tensor(size_splits)
                graph_converter.add_operator(
                    tfl.SplitVOperator([input_tensor, split_tensor, dim_tensor], outputs, chunks)
                )
            else:
                # Even split: note the operand order differs from SPLIT_V —
                # TFLite SPLIT takes the axis tensor first
                graph_converter.add_operator(tfl.SplitOperator([dim_tensor, input_tensor], outputs, chunks))
class PrimPythonOpConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Handle a `prim::PythonOp` by running it eagerly and scheduling its subgraph.

        The Python callable is invoked with the current input tensors plus the
        node's scalar args to produce the eager output. The subgraph's nodes
        are then queued for conversion in order: param node, body nodes,
        return node.
        """
        subgraph = attrs['Subgraph'][0]
        param_node = subgraph.param_node()
        return_node = subgraph.return_node()
        self.output_tensors.append(node.pyobj()(*self.input_tensors, *node.scalar_args()))
        # Order matters: params first, then the body, then the return node
        self.output_nodes.append(param_node)
        self.output_nodes.extend(subgraph.nodes())
        self.output_nodes.append(return_node)

    def prepare_scope_tensors(self, node, attrs, args, graph_converter, scope_name):
        """Pre-register TFLite tensors for the subgraph's boundary values.

        Maps the scoped names of the subgraph's parameters and return values
        to outer-graph TFL tensors via `graph_converter.constant_mapping`, so
        the inner converters can resolve them.
        """
        subgraph = attrs['Subgraph'][0]
        # input tensors
        param_node = subgraph.param_node()
        input_tensors = [self.find_or_create_input(i, graph_converter) for i in range(len(self.input_tensors))]
        subgraph_input_names = [self.get_tensor_name(x.debugName(), scope_name) for x in param_node.outputs()]
        for name, t in zip(subgraph_input_names, input_tensors):
            graph_converter.constant_mapping[name] = t
        # output tensors
        return_node = subgraph.return_node()
        subgraph_output_names = [self.get_tensor_name(x.debugName(), scope_name) for x in return_node.inputs()]
        output_tensors = self.to_tfl_tensors(self.output_names, self.output_tensors)
        for name, t in zip(subgraph_output_names, output_tensors):
            graph_converter.constant_mapping[name] = t
class PrimReturnConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Connect subgraph return values to their pre-registered outer tensors.

        For every input that is tracked in the graph, insert a no-op reshape
        forwarding the inner tensor to the outer tensor recorded in
        `graph_converter.constant_mapping` (see PrimPythonOpConverter).
        """
        for i, name in enumerate(self.input_names):
            assert name in graph_converter.constant_mapping
            if name in graph_converter.tensor_map:
                input_tensor = self.find_or_create_input(i, graph_converter)
                output_tensor = graph_converter.constant_mapping[name]
                # Fix: build the shape operand as an int32 ndarray, matching the
                # convention used elsewhere in this file (e.g. PrimListUnpackConverter).
                # Passing the raw shape sequence would let numpy default to int64,
                # which TFLite's RESHAPE does not accept.
                shape_tensor = self.create_attr_tensor(
                    np.array(input_tensor.shape, dtype='int32'), name=f'{name}_return_attr'
                )
                inputs = [input_tensor, shape_tensor]
                outputs = [output_tensor]
                graph_converter.add_operator(tfl.ReshapeOperator(inputs, outputs, input_tensor.shape))
class PrimParamConverter(PrimOperatorConverter):
    def parse(self, node, attrs, args, graph_converter):
        """Bind subgraph parameters to the outer tensors captured beforehand.

        Each output name must have been registered in
        `graph_converter.constant_mapping` (see PrimPythonOpConverter). The
        eager value is recovered from the captured tensor's numpy backing, and
        tracked tensors are forwarded into the subgraph via a no-op reshape.
        """
        for i, name in enumerate(self.output_names):
            assert name in graph_converter.constant_mapping
            input_tensor = graph_converter.constant_mapping[name]
            output_tensor = self.to_tfl_tensors([name], [input_tensor.tensor])[0]
            self.output_tensors.append(torch.from_numpy(input_tensor.tensor))
            if input_tensor.name in graph_converter.tensor_map:
                # Fix: build the shape operand as an int32 ndarray, matching the
                # convention used elsewhere in this file (e.g. PrimListUnpackConverter).
                # Passing the raw shape sequence would let numpy default to int64,
                # which TFLite's RESHAPE does not accept.
                shape_tensor = self.create_attr_tensor(
                    np.array(input_tensor.shape, dtype='int32'), name=f'{name}_return_attr'
                )
                inputs = [input_tensor, shape_tensor]
                outputs = [output_tensor]
                graph_converter.add_operator(tfl.ReshapeOperator(inputs, outputs, input_tensor.shape))