# tinynn/util/converter_util.py
import json
import os
import typing
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
from tinynn.util.util import get_logger
# Module-level logger, created via tinynn's logging helper.
log = get_logger(__name__)
def tensor_config(
    tensors: typing.List[torch.Tensor],
    transpose: typing.List[typing.Optional[bool]],
    with_shape: bool,
) -> typing.List[OrderedDict]:
    """Generate the tensor info needed for the config file for the TinyNeuralNetwork converter

    Args:
        tensors (typing.List[torch.Tensor]): The tensors to gather info
        transpose (typing.List[typing.Optional[bool]]): Whether to perform nchw-nhwc transpose for
            each tensor; a `None` entry means "transpose iff the tensor is 4-dimensional"
        with_shape (bool): Whether to dump the shape of the tensor

    Returns:
        typing.List[OrderedDict]: The info of the tensors, one dict per tensor
            (fixes the previous annotation/docstring, which claimed a single OrderedDict)
    """
    tensor_list = []
    for t, trans in zip(tensors, transpose):
        tensor_dict = OrderedDict()
        if with_shape:
            tensor_dict['shape'] = list(t.shape)
        try:
            # Ordinary tensors map straight to a numpy dtype name (e.g. 'float32')
            tensor_dict['type'] = str(t.detach().numpy().dtype)
        except Exception:
            # Quantized tensors (torch.qint8, ...) have no numpy counterpart; fall back to the
            # torch dtype string with the 'torch.q' / 'torch.' prefix stripped
            type_str = str(t.dtype)
            if 'torch.q' in type_str:
                type_str = type_str.replace('torch.q', '')
            else:
                type_str = type_str.replace('torch.', '')
            tensor_dict['type'] = type_str
        if trans is None:
            # Default policy: only 4D tensors (assumed NCHW) get transposed
            trans = len(t.shape) == 4
        tensor_dict['transpose'] = trans
        tensor_list.append(tensor_dict)
    return tensor_list
def generate_converter_config(
    inputs: typing.List[torch.Tensor],
    outputs: typing.List[torch.Tensor],
    input_transpose: typing.Union[typing.Iterable[bool], bool],
    output_transpose: typing.Union[typing.Iterable[bool], bool],
    export_file: str,
    tflite_path: typing.Optional[str] = None,
    config_path: typing.Optional[str] = None,
):
    """Generate a config file that will work for the TinyNeuralNetwork converter

    Args:
        inputs (typing.List[torch.Tensor]): The input tensors
        outputs (typing.List[torch.Tensor]): The output tensors
        input_transpose (typing.Union[typing.Iterable[bool], bool]): Whether to insert transpose
            after the input nodes; a single bool/None is broadcast to every input
        output_transpose (typing.Union[typing.Iterable[bool], bool]): Whether to insert transpose
            after the output nodes; a single bool/None is broadcast to every output
        export_file (str): The path of the generated torchscript model
        tflite_path (str): The path of the generated tflite model. Defaults to None
            (derived from `export_file` by replacing '.pt' with '.tflite').
        config_path (str): The path of the generated config. Defaults to None
            (derived from `export_file` by replacing '.pt' with '.json').

    Raises:
        AssertionError: a transpose argument is neither a bool/None nor a list of booleans
            matching the number of corresponding tensors
    """
    if type(input_transpose) in (tuple, list):
        if len(input_transpose) != len(inputs) or not all(type(x) is bool for x in input_transpose):
            raise AssertionError('input transpose should either be boolean or list of booleans')
    elif type(input_transpose) is bool or input_transpose is None:
        input_transpose = [input_transpose] * len(inputs)
    else:
        raise AssertionError('input transpose should either be boolean or list of booleans')

    if type(output_transpose) in (tuple, list):
        if len(output_transpose) != len(outputs) or not all(type(x) is bool for x in output_transpose):
            raise AssertionError('output transpose should either be boolean or list of booleans')
    elif type(output_transpose) is bool or output_transpose is None:
        # Bugfix: broadcast over the OUTPUTS, not the inputs. The old code used len(inputs),
        # which silently truncated (via zip in tensor_config) or over-extended the output
        # entries whenever the input and output counts differed.
        output_transpose = [output_transpose] * len(outputs)
    else:
        raise AssertionError('output transpose should either be boolean or list of booleans')

    if tflite_path is None:
        tflite_path = export_file.replace('.pt', '.tflite')
    if config_path is None:
        config_path = export_file.replace('.pt', '.json')

    json_obj = OrderedDict(
        {
            'src_model': export_file,
            'dst_model': tflite_path,
            # Inputs carry their shapes; outputs only need dtype/transpose info
            'inputs': tensor_config(inputs, input_transpose, True),
            'outputs': tensor_config(outputs, output_transpose, False),
        }
    )

    with open(config_path, 'w') as f:
        json.dump(json_obj, f, indent=4)
def export_converter_files(
    model: nn.Module,
    dummy_input: typing.Union[torch.Tensor, typing.Iterable[torch.Tensor]],
    export_dir: typing.Optional[str] = None,
    model_name: typing.Optional[str] = None,
    input_transpose: typing.Optional[typing.Union[bool, typing.Iterable[bool]]] = None,
    output_transpose: typing.Optional[typing.Union[bool, typing.Iterable[bool]]] = None,
    dump_graph: bool = False,
):
    """Automatically generate required files for the model converter

    Args:
        model (nn.Module): The input model
        dummy_input (typing.Union[torch.Tensor, typing.Iterable[torch.Tensor]]): A viable input to the model
        export_dir (typing.Optional[str], optional): Directory to use for exporting. Defaults to None (os.getcwd()).
        model_name (typing.Optional[str], optional): File name for exporting. Defaults to None ("jit_model").
        input_transpose (typing.Optional[typing.Union[bool, typing.Iterable[bool]]], optional): Whether to \
            transpose the input(s). Defaults to None (True for 4d-input, False otherwise).
        output_transpose (typing.Optional[typing.Union[bool, typing.Iterable[bool]]], optional): Whether to \
            transpose the output(s). Defaults to None (True for 4d-output, False otherwise).
        dump_graph (bool, optional): Whether to print the traced graph. Defaults to False.
    """
    export_dir = os.getcwd() if export_dir is None else export_dir
    name = 'jit_model' if model_name is None else model_name
    export_file = os.path.abspath(os.path.join(export_dir, f'{name}.pt'))

    # Trace the model into a torchscript module and serialize it
    traced = torch.jit.trace(model, dummy_input)
    if dump_graph:
        log.info(traced.inlined_graph)
    os.makedirs(export_dir, exist_ok=True)
    torch.jit.save(traced, export_file)

    # Normalize dummy_input to a flat list of input tensors
    if type(dummy_input) in (tuple, list):
        inputs = dummy_input
    else:
        inputs = [dummy_input]

    # Run the model once to collect the outputs, flattening one level of nesting
    output = model(*inputs)
    if type(output) in (tuple, list):
        outputs = []
        for item in output:
            if type(item) in (tuple, list):
                outputs.extend(item)
            else:
                outputs.append(item)
    else:
        outputs = [output]

    generate_converter_config(inputs, outputs, input_transpose, output_transpose, export_file)
def get_tensor_details(config):
    """Pull per-tensor details out of a parsed converter config dict.

    Returns:
        Tuple of (input_shapes, input_transpose, input_dtypes, output_transpose),
        each a list with one entry per input/output in the config.
    """
    input_entries = config['inputs']
    input_shapes = [entry['shape'] for entry in input_entries]
    input_transpose = [entry['transpose'] for entry in input_entries]
    input_dtypes = [entry['type'] for entry in input_entries]
    output_transpose = [entry['transpose'] for entry in config['outputs']]
    return input_shapes, input_transpose, input_dtypes, output_transpose
def prepare_input_arrays(input_shapes, input_transpose, input_dtypes):
    """Build a zero-filled numpy array for each input description.

    Arrays flagged in `input_transpose` are permuted NCHW -> NHWC
    (assumes those inputs are 4-dimensional — TODO confirm with callers).
    """
    arrays = []
    for idx, shape in enumerate(input_shapes):
        arr = np.zeros(shape, dtype=input_dtypes[idx])
        if input_transpose[idx]:
            # NCHW -> NHWC layout swap
            arr = arr.transpose((0, 2, 3, 1))
        arrays.append(arr)
    return arrays
def data_to_pytorch(inputs, input_transpose):
    """Convert numpy input arrays to torch tensors.

    Tensors flagged in `input_transpose` are permuted back NHWC -> NCHW
    (the inverse of the layout swap done in `prepare_input_arrays`).
    """
    tensors = [torch.from_numpy(arr) for arr in inputs]
    result = []
    for idx, tensor in enumerate(tensors):
        if input_transpose[idx]:
            tensor = tensor.permute(0, 3, 1, 2)
        result.append(tensor)
    return result
def parse_config(
    json_file: str, prepare_inputs: bool = True
) -> typing.Tuple[str, str, typing.List[bool], typing.Optional[typing.List[torch.Tensor]], typing.List[bool]]:
    """Parses the configuration file for converter

    Args:
        json_file (str): The path of the configuration file
        prepare_inputs (bool, optional): Whether to prepare inputs. Defaults to True.

    Returns:
        tflite_model_path (str): The path used to place the generated tflite model
        torch_model_path (str): The path of the torchscript model
        input_transpose (typing.List[bool]): Flag variables whether the inputs will be transposed (nchw -> nhwc)
        torch_inputs (typing.List[torch.Tensor]): The prepared inputs if prepare_inputs is True, otherwise None
        output_transpose (typing.List[bool]): Flag variables whether the outputs will be transposed (nchw -> nhwc)
    """
    with open(json_file, 'r') as f:
        config = json.load(f)

    tflite_model_path = config['dst_model']
    torch_model_path = config['src_model']

    # Bugfix: extract the transpose flags unconditionally. Previously they were only
    # assigned inside the `prepare_inputs` branch, so calling with prepare_inputs=False
    # raised UnboundLocalError at the return statement.
    input_transpose = [inp['transpose'] for inp in config['inputs']]
    output_transpose = [outp['transpose'] for outp in config['outputs']]

    if prepare_inputs:
        input_shapes = [inp['shape'] for inp in config['inputs']]
        input_dtypes = [inp['type'] for inp in config['inputs']]
        inputs = prepare_input_arrays(input_shapes, input_transpose, input_dtypes)
        torch_inputs = data_to_pytorch(inputs, input_transpose)
    else:
        torch_inputs = None

    return torch_model_path, tflite_model_path, input_transpose, torch_inputs, output_transpose