tinynn/converter/utils/tensorrt.py (252 lines of code) (raw):
import os
import re
from typing import Optional
import onnx
import onnxruntime as ort
import numpy as np
import pycuda.autoinit # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
from contextlib import contextmanager
# TensorRT loggers, each printing only messages at or above its severity.
# Available levels: WARNING, ERROR, VERBOSE, INFO
TRT_BUILD_LOGGER = trt.Logger(trt.Logger.WARNING)  # handed to build_engine in this module
TRT_EVAL_LOGGER = trt.Logger(trt.Logger.ERROR)  # handed to trt.Runtime when deserializing engines
# Networks are always created in explicit-batch mode; there is currently no switch to disable it.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
class MyLogger(trt.ILogger):
    """TensorRT logger that forwards every message to ``inner`` and captures the layer dump.

    The most recent message starting with ``'Engine Layer Information:'`` is kept in
    ``info_log``; it stays ``None`` until such a message is logged.
    """

    def __init__(self, inner):
        super().__init__()
        # Logger that receives every message unchanged.
        self.inner = inner
        # Last "Engine Layer Information:" message, or None.
        self.info_log = None

    def log(self, severity, msg):
        is_layer_dump = msg.startswith('Engine Layer Information:')
        if is_layer_dump:
            self.info_log = msg
        self.inner.log(severity, msg)
class HostDeviceMem:
    """Pair of a host-side buffer and its device-side counterpart."""

    def __init__(self, host_mem, device_mem):
        self.host = host_mem  # host buffer
        self.device = device_mem  # device buffer

    def __str__(self):
        return f"Host:\n{self.host!s}\nDevice:\n{self.device!s}"

    def __repr__(self):
        return str(self)
def allocate_buffers(engine, dynamic_shapes_for_build, dynamc_shapes_for_eval):
    """Allocate a page-locked host buffer and a device buffer for every engine binding.

    The element count of each buffer comes from, in order of preference:
      * the shape given for the binding in ``dynamc_shapes_for_eval``;
      * the second (opt) shape given for it in ``dynamic_shapes_for_build``;
      * ``engine.get_binding_shape(binding)`` multiplied by ``engine.max_batch_size``.

    Returns:
        ``(inputs, outputs, bindings, stream, names_map_i, names_map_o)``:
        ``inputs``/``outputs`` are lists of HostDeviceMem, ``bindings`` holds the device
        pointers as ints in engine binding order, ``stream`` is a new ``cuda.Stream``, and
        ``names_map_i``/``names_map_o`` map a binding name to its index in ``inputs``/``outputs``.
    """
    inputs, outputs, bindings = [], [], []
    names_map_i, names_map_o = {}, {}
    stream = cuda.Stream()
    for name in engine:
        if name in dynamc_shapes_for_eval:
            element_count = trt.volume(dynamc_shapes_for_eval[name])
        elif name in dynamic_shapes_for_build:
            element_count = trt.volume(dynamic_shapes_for_build[name][1])
        else:
            # NOTE(review): output bindings with dynamic dims also land here; their binding shape
            # presumably contains -1, which would yield a negative volume — confirm whether
            # models with dynamic output shapes are supported.
            element_count = trt.volume(engine.get_binding_shape(name)) * engine.max_batch_size
        np_dtype = trt.nptype(engine.get_binding_dtype(name))

        host_buf = cuda.pagelocked_empty(element_count, np_dtype)
        device_buf = cuda.mem_alloc(host_buf.nbytes)
        bindings.append(int(device_buf))

        pair = HostDeviceMem(host_buf, device_buf)
        if not engine.binding_is_input(name):
            names_map_o[name] = len(outputs)
            outputs.append(pair)
            print('output:', engine.get_binding_shape(name), np_dtype, name)
        else:
            names_map_i[name] = len(inputs)
            inputs.append(pair)
            print('input:', engine.get_binding_shape(name), np_dtype, name)
    return inputs, outputs, bindings, stream, names_map_i, names_map_o
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one asynchronous inference pass on ``stream`` and wait for it to complete.

    ``inputs`` and ``outputs`` are lists of HostDeviceMem. Each input host buffer is copied to
    its device buffer, ``context.execute_async`` is launched, and each output device buffer is
    copied back to its host buffer before the stream is synchronized.

    Returns:
        The host output buffers, in the same order as ``outputs``.
    """
    for src in inputs:
        cuda.memcpy_htod_async(src.device, src.host, stream)

    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)

    for dst in outputs:
        cuda.memcpy_dtoh_async(dst.host, dst.device, stream)

    # Block until the copies and the kernel launches queued above have finished.
    stream.synchronize()
    return [dst.host for dst in outputs]
def GiB(val):
    """Return the number of bytes in ``val`` gibibytes (1 GiB = 2**30 bytes) as an int.

    ``val`` may be fractional (e.g. ``0.5``); the byte count is truncated to an int.
    """
    # Parenthesize the shift explicitly: the old ``val * 1 << 30`` only worked through
    # operator precedence and raised TypeError for float values.
    return int(val * (1 << 30))
def add_profile(builder, config, dynamic_shapes_for_build):
    """Register a single optimization profile covering every entry of ``dynamic_shapes_for_build``.

    ``dynamic_shapes_for_build`` maps an input name to ``(min_shape, opt_shape, max_shape)``.
    When the mapping is empty, no profile is created.
    """
    if len(dynamic_shapes_for_build) == 0:
        return
    profile = builder.create_optimization_profile()
    for name, shapes in dynamic_shapes_for_build.items():
        min_shape, opt_shape, max_shape = shapes
        profile.set_shape(name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)
@contextmanager
def build_engine(
    model_path, logger, build_with_fp16, build_with_int8, build_with_workspace, dynamic_shapes_for_build=None
):
    """Build a TensorRT engine from an ONNX model and yield it.

    Args:
        model_path: Path to the ONNX model file.
        logger: TensorRT logger given to both the builder and the ONNX parser.
        build_with_fp16: Enable the FP16 builder flag.
        build_with_int8: Enable the INT8 builder flag. It can be combined with FP16.
        build_with_workspace: Builder workspace size in GiB.
        dynamic_shapes_for_build: Optional mapping of input name -> (min, opt, max) shapes.
            When non-empty, an optimization profile is added. Defaults to no profile.

    Yields:
        The value returned by ``builder.build_engine`` (presumably ``None`` when the build
        fails — TODO confirm for the TensorRT version in use).

    Raises:
        RuntimeError: If the ONNX parser reports errors. The errors are printed first.
    """
    if dynamic_shapes_for_build is None:
        dynamic_shapes_for_build = {}
    with trt.Builder(logger) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(
        network, logger
    ) as parser:
        with open(model_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                # Building from a partially parsed network would silently yield a wrong engine.
                raise RuntimeError(f'Failed to parse ONNX model: {model_path}')
        with builder.create_builder_config() as config:
            # set_flag adds to the existing flags. Assigning config.flags would make INT8
            # overwrite FP16 and would also clear any flags that were enabled by default.
            if build_with_fp16:
                config.set_flag(trt.BuilderFlag.FP16)
            if build_with_int8:
                config.set_flag(trt.BuilderFlag.INT8)
            config.max_workspace_size = GiB(build_with_workspace)
            add_profile(builder, config, dynamic_shapes_for_build)
            yield builder.build_engine(network, config)
def get_needed_outputs(lines):
    """Collect output tensor names from a TensorRT "Engine Layer Information" dump.

    Only lines that start with ``Layer(`` and contain ``->`` are considered. The text after
    ``-> `` lists the layer outputs, e.g. ``a[Float(1,2)], b[Float(1,2)]``. The bracket style
    around dtype/shape, and the separator between multiple outputs, are detected from the first
    line where they can be recognized and then reused for all following lines. Lines carrying
    no recognizable dtype annotation are skipped until the format is known.

    Args:
        lines: Iterable of text lines from the layer-information log.

    Returns:
        A set of output tensor names. Names containing ``[``, ``(`` or ``+`` are excluded, and so
        are outputs matching ``Reformatted .* to .*``. Each kept name is printed.
    """
    # Opening of the first "<dtype bracket><DType><shape bracket>" annotation, e.g. "[Float(".
    dtype_pattern = re.compile(r'( *(\[|\())(Int32|Float|Half|Int8)(\(|\[)')
    start_end_dict = {'[': ']', '(': ')'}
    needed_outputs = set()
    # Bracket style and output separator, learned lazily from the log itself.
    dtype_start = None
    shape_start = None
    dtype_end = None
    shape_end = None
    sep = None
    for line in lines:
        if not line.startswith('Layer('):
            continue
        line = line.rstrip('\n')
        op = re.findall(r'Layer\((.*?)\)', line)[0]
        if '->' not in line:
            continue
        pos = line.find('-> ')
        outputs = line[pos + 3 :]
        if dtype_start is None:
            match = dtype_pattern.search(outputs)
            if match is None:
                # Nothing on this line to learn the format from. Try the next line instead of
                # aborting with an opaque StopIteration.
                continue
            dtype_start = match.group(1)
            shape_start = match.group(4)
            dtype_end = start_end_dict[dtype_start.lstrip()]
            shape_end = start_end_dict[shape_start]
        if sep is None:
            # A separator only appears when several outputs are listed. It is the closing
            # brackets of the first output, followed by the text up to the next space, e.g. ")], ".
            end_str = f'{shape_end}{dtype_end}'
            pos = outputs.find(end_str)
            assert pos > 0
            start = pos + len(shape_end) + len(dtype_end)
            if start != len(outputs):
                end = outputs.find(' ', start)
                sep = f'{end_str}{outputs[start:end]} '
        if sep is not None and sep in outputs:
            outputs = outputs.split(sep)
        else:
            outputs = [outputs]
        for output in outputs:
            if dtype_start not in output:
                continue
            pos = output.rfind(dtype_start)
            output_name = output[:pos]
            # Names with brackets or '+' presumably denote fused/internal tensors; skip them.
            if '[' in output_name or '(' in output_name or '+' in output_name:
                continue
            if re.match('Reformatted .* to .*', output):
                continue
            print('Observing', op, repr(output_name))
            needed_outputs.add(output_name)
    return needed_outputs
def add_outputs_for_onnx_model(needed_outputs, onnx_path, new_onnx_path):
    """Expose selected intermediate tensors of an ONNX model as graph outputs.

    Loads ``onnx_path`` and appends a bare ``ValueInfoProto`` graph output for every node output
    whose name is in ``needed_outputs`` and is not already a graph output. Shape inference is
    then run on the model, and the result is saved to ``new_onnx_path``.
    """
    model = onnx.load(onnx_path)
    existing = {graph_output.name for graph_output in model.graph.output}
    print()
    for node in model.graph.node:
        for tensor_name in node.output:
            if tensor_name not in needed_outputs or tensor_name in existing:
                continue
            model.graph.output.extend([onnx.ValueInfoProto(name=tensor_name)])
            print('Added output:', tensor_name)
    # Shape inference fills in the dtype/shape of the outputs added above.
    onnx.save_model(onnx.shape_inference.infer_shapes(model), new_onnx_path)
    print('Modified model saved at', new_onnx_path)
def _report_output_match(binding, onnx_val, trt_val):
    """Print whether one TensorRT output matches ONNX Runtime, plus the worst elements if it does not."""
    is_aligned = np.allclose(onnx_val, trt_val)
    if is_aligned:
        print(binding, 'matches:', is_aligned)
        return
    onnx_val_ravel = onnx_val.ravel()
    trt_val_ravel = trt_val.ravel()
    # Cosine similarity accepts small elementwise drift that allclose rejects.
    cross_sim = np.dot(onnx_val_ravel, trt_val_ravel) / (
        np.linalg.norm(onnx_val_ravel) * np.linalg.norm(trt_val_ravel)
    )
    is_aligned = cross_sim > 0.999
    print(binding, 'matches:', is_aligned, 'cross_sim =', cross_sim)
    if not is_aligned:
        print('Top 10 values with maximum differences')
        max_diff_indices = np.argsort(np.abs(onnx_val - trt_val).ravel())[::-1][:10]
        print('TensorRT:')
        print(trt_val_ravel[max_diff_indices])
        print('ONNX:')
        print(onnx_val_ravel[max_diff_indices])
        print('-' * 60)


def compare_onnx_tensorrt(
    onnx_path: str,
    build_with_fp16: bool,
    build_with_int8: bool,
    build_with_workspace: int = 4,
    dynamic_shapes_for_build: Optional[dict] = None,
    dynamc_shapes_for_eval: Optional[dict] = None,
    input_path_mapping: Optional[dict] = None,
):
    """Compare the tensors TensorRT materializes against ONNX Runtime on the same inputs.

    Steps:
      1. Build an engine from ``onnx_path`` only to capture its layer-information log.
      2. Add every tensor listed in that log as a graph output, producing
         ``<name>_with_outputs<ext>``.
      3. Build that model and serialize the engine to ``<name>_with_outputs.trt``.
      4. Run the engine and ONNX Runtime (CUDA provider) on the same inputs, then print a match
         report for each output.

    Args:
        onnx_path: Path to the ONNX model.
        build_with_fp16: Enable the FP16 builder flag.
        build_with_int8: Enable the INT8 builder flag.
        build_with_workspace: Builder workspace size in GiB.
        dynamic_shapes_for_build: Input name -> (min_shape, opt_shape, max_shape) for the
            optimization profile. The opt shape is used for inputs that have no eval shape.
        dynamc_shapes_for_eval: Input name -> concrete shape used for inference.
        input_path_mapping: Input name -> raw binary file holding that input, stored in the
            binding's dtype. Every name here must also appear in ``dynamc_shapes_for_eval``.
            Inputs not listed are filled from ``np.random.random``.
    """
    if dynamic_shapes_for_build is None:
        dynamic_shapes_for_build = {}
    if dynamc_shapes_for_eval is None:
        dynamc_shapes_for_eval = {}
    if input_path_mapping is None:
        input_path_mapping = {}

    onnx_fn, onnx_ext = os.path.splitext(onnx_path)
    new_onnx_path = f'{onnx_fn}_with_outputs{onnx_ext}'
    # Must differ from new_onnx_path. Otherwise the serialized engine overwrites the modified
    # ONNX model that ONNX Runtime loads below.
    new_trt_path = f'{onnx_fn}_with_outputs.trt'

    print('Building TensorRT engine with', onnx_path)
    # This first build only captures the "Engine Layer Information" log; the engine is discarded.
    logger = MyLogger(TRT_BUILD_LOGGER)
    with build_engine(
        onnx_path, logger, build_with_fp16, build_with_int8, build_with_workspace, dynamic_shapes_for_build
    ):
        pass
    assert logger.info_log is not None, "Engine layer information is missing"
    lines = logger.info_log.splitlines()
    needed_outputs = get_needed_outputs(lines)
    add_outputs_for_onnx_model(needed_outputs, onnx_path, new_onnx_path)

    # The modified model needs the same optimization profile for its dynamic inputs.
    with build_engine(
        new_onnx_path,
        TRT_BUILD_LOGGER,
        build_with_fp16,
        build_with_int8,
        build_with_workspace,
        dynamic_shapes_for_build,
    ) as engine:
        if engine is None:
            raise RuntimeError(f'Failed to build TensorRT engine for {new_onnx_path}')
        with open(new_trt_path, 'wb') as f:
            f.write(bytearray(engine.serialize()))

    runtime = trt.Runtime(TRT_EVAL_LOGGER)
    with open(new_trt_path, 'rb') as f:
        engine_bytes = f.read()
    engine = runtime.deserialize_cuda_engine(engine_bytes)

    print('=' * 60)
    print('input output tensors:')
    inputs, outputs, bindings, stream, names_map_i, names_map_o = allocate_buffers(
        engine, dynamic_shapes_for_build, dynamc_shapes_for_eval
    )
    print('=' * 60)

    input_data = {}
    with engine.create_execution_context() as context:
        for binding in names_map_i:
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            binding_index = engine.get_binding_index(binding)
            if binding in input_path_mapping:
                assert (
                    binding in dynamc_shapes_for_eval
                ), "input_path_mapping and dynamc_shapes_for_eval should be specified together"
                shape = dynamc_shapes_for_eval[binding]
                data = np.fromfile(input_path_mapping[binding], dtype='uint8')
                data = np.reshape(data.view(dtype), shape)
                context.set_binding_shape(binding_index, shape)
            else:
                if binding in dynamc_shapes_for_eval:
                    shape = dynamc_shapes_for_eval[binding]
                    context.set_binding_shape(binding_index, shape)
                else:
                    shape = engine.get_binding_shape(binding)
                    if -1 in shape:
                        # Dynamic input without an eval shape: fall back to the profile's opt shape.
                        shape = dynamic_shapes_for_build[binding][1]
                        context.set_binding_shape(binding_index, shape)
                data = np.random.random(shape).astype(dtype)
            input_data[binding] = data
            np.copyto(inputs[names_map_i[binding]].host, data.ravel())
        output = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

    output_keys = list(names_map_o)
    opts = ort.SessionOptions()
    opts.intra_op_num_threads = 1
    opts.inter_op_num_threads = 1
    sess = ort.InferenceSession(new_onnx_path, providers=['CUDAExecutionProvider'], sess_options=opts)
    res = sess.run(output_keys, input_data)
    for binding, onnx_val in zip(output_keys, res):
        if onnx_val is None:
            continue
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        trt_val = output[names_map_o[binding]].view(dtype)
        if dtype != onnx_val.dtype:
            trt_val = trt_val.astype(onnx_val.dtype)
        trt_val = np.reshape(trt_val, onnx_val.shape)
        _report_output_match(binding, onnx_val, trt_val)