# tzrec/acc/trt_utils.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
from typing import Any, Dict, List, Optional, Sequence

import torch

# cpu image has no torch_tensorrt
try:
    import torch_tensorrt
except Exception:
    pass
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function

from tzrec.acc.utils import is_debug_trt
from tzrec.models.model import ScriptWrapper
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger


def trt_convert(
    exp_program: torch.export.ExportedProgram,
    # pyre-ignore [2]
    inputs: Optional[Sequence[Sequence[Any]]],
) -> torch.fx.GraphModule:
    """Convert a model using TensorRT.

    Args:
        exp_program (torch.export.ExportedProgram): the source exported program.
        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): example inputs.

    Returns:
        torch.fx.GraphModule: the compiled FX module; running it executes
            via TensorRT.
    """
    logger.info("trt convert start...")
    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
    enabled_precisions = {torch.float32}
    # Workspace size for TensorRT
    workspace_size = 2 << 30
    # Minimum number of operators per TRT engine block
    # (lower value allows more graph segmentation)
    min_block_size = 2
    # note: when using a scripted model, dict inputs are not supported
    # dump the partitioned graphs when debugging, otherwise compile quietly
    log_ctx = (
        torch_tensorrt.logging.graphs()
        if is_debug_trt()
        else contextlib.nullcontext()
    )
    with log_ctx:
        optimized_model = torch_tensorrt.dynamo.compile(
            exp_program,
            inputs,
            # pyre-ignore [6]
            enabled_precisions=enabled_precisions,
            workspace_size=workspace_size,
            min_block_size=min_block_size,
            hardware_compatible=True,
            assume_dynamic_shape_support=True,
            # truncate_long_and_double=True,
            allow_shape_tensors=True,
        )
    logger.info("trt convert end")
    return optimized_model
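

# Example usage of `trt_convert` (a minimal sketch; `dense` and
# `example_inputs` are hypothetical placeholders for an fx-traced dense
# module and its CUDA example inputs):
#
#   exp_program = torch.export.export(dense, (example_inputs,))
#   dense_trt = trt_convert(exp_program, example_inputs)
#   out = dense_trt(example_inputs)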


class ScriptWrapperList(ScriptWrapper):
    """Model inference wrapper for jit.script.

    ScriptWrapperList is used to trace the
    ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt) and to return a list
    of Tensors instead of a dict of Tensors.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__(module)

    # pyre-ignore [15]
    def forward(
        self,
        data: Dict[str, torch.Tensor],
        # pyre-ignore [9]
        device: torch.device = "cpu",
    ) -> List[torch.Tensor]:
        """Predict the model.

        Args:
            data (dict): a dict of input data for Batch.
            device (torch.device): inference device.

        Returns:
            predictions (list): a list of predicted Tensors.
        """
        batch = self.get_batch(data, device)
        return self.model.predict(batch)


class ScriptWrapperTRT(nn.Module):
    """Model inference wrapper for jit.script."""

    def __init__(self, embedding_group: nn.Module, dense: nn.Module) -> None:
        super().__init__()
        self.embedding_group = embedding_group
        self.dense = dense

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        # pyre-ignore [9]
        device: torch.device = "cuda:0",
    ) -> Dict[str, torch.Tensor]:
        """Predict the model.

        Args:
            data (dict): a dict of input data for Batch.
            device (torch.device): inference device.

        Returns:
            predictions (dict): a dict of predicted results.
        """
        grouped_features = self.embedding_group(data, device)
        y = self.dense(grouped_features)
        return y
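

# Example usage (a minimal sketch; `emb_mod` and `dense_mod` are
# hypothetical placeholders for an already-scripted embedding group and a
# TRT-compiled dense module, as produced by `export_model_trt` below):
#
#   wrapper = ScriptWrapperTRT(emb_mod, dense_mod)
#   predictions = wrapper(data, "cuda:0")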


def get_trt_max_batch_size() -> int:
    """Get trt max batch size.

    Returns:
        int: max_batch_size
    """
    return int(os.environ.get("TRT_MAX_BATCH_SIZE", 2048))


def get_trt_max_seq_len() -> int:
    """Get trt max seq len.

    Returns:
        int: max_seq_len
    """
    return int(os.environ.get("TRT_MAX_SEQ_LEN", 100))
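

# Both limits can be overridden via environment variables, e.g. (a
# hypothetical invocation):
#
#   TRT_MAX_BATCH_SIZE=4096 TRT_MAX_SEQ_LEN=200 python -m tzrec.export ...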


def export_model_trt(
    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
) -> None:
    """Export the TRT model.

    Args:
        model (nn.Module): the model.
        data (Dict[str, torch.Tensor]): the test data.
        save_dir (str): model save dir.
    """
    # use ScriptWrapperList to trace the
    # ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
    emb_trace_gpu = ScriptWrapperList(model.model.embedding_group)
    emb_res = emb_trace_gpu(data, "cuda:0")
    emb_trace_gpu = symbolic_trace(emb_trace_gpu)
    emb_trace_gpu = torch.jit.script(emb_trace_gpu)
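    # the embedding group stays as TorchScript (fx-traced then scripted);
    # only the dense part below is compiled with TensorRT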
    # dynamic shapes
    max_batch_size = get_trt_max_batch_size()
    max_seq_len = get_trt_max_seq_len()
    batch = torch.export.Dim("batch", min=1, max=max_batch_size)
    dynamic_shapes_list = []
    values_list_cuda = []
    for i, value in enumerate(emb_res):
        v = value.detach().to("cuda:0")
        dict_dy = {0: batch}
        if v.dim() == 3:
            # workaround for 0/1 specialization: torch.export specializes
            # dims of size 0 or 1 into constants, so pad them to at least 2
            if v.size(1) < 2:
                v = torch.zeros(
                    v.size(0), 2, v.size(2), device="cuda:0", dtype=v.dtype
                )
            dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=max_seq_len)
        if v.size(0) < 2:
            v = torch.zeros((2,) + v.size()[1:], device="cuda:0", dtype=v.dtype)
        values_list_cuda.append(v)
        dynamic_shapes_list.append(dict_dy)
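    # at this point values_list_cuda holds padded CUDA example inputs, and
    # dynamic_shapes_list marks the batch (and per-feature seq_len) dims
    # as dynamic for torch.export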
    # convert dense
    dense = model.model.dense
    logger.info("dense res: %s", dense(values_list_cuda))
    dense_layer = symbolic_trace(dense)
    dynamic_shapes = {"args": dynamic_shapes_list}
    exp_program = torch.export.export(
        dense_layer, (values_list_cuda,), dynamic_shapes=dynamic_shapes
    )
    dense_layer_trt = trt_convert(exp_program, values_list_cuda)
    dict_res = dense_layer_trt(values_list_cuda)
    logger.info("dense trt res: %s", dict_res)
    # save combined_model
    combined_model = ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
    result = combined_model(data, "cuda:0")
    logger.info("combined model result: %s", result)
    # combined_model = symbolic_trace(combined_model)
    combined_model = torch.jit.trace(
        combined_model, example_inputs=(data,), strict=False
    )
    scripted_model = torch.jit.script(combined_model)
    # pyre-ignore [16]
    scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))
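    # in debug mode, profile the eager dense, the TRT-compiled dense, and
    # the reloaded combined model to compare runtimes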
    if is_debug_trt():
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            with record_function("model_inference_dense"):
                dict_res = dense(values_list_cuda)
        logger.info(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)
        )

        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            with record_function("model_inference_dense_trt"):
                dict_res = dense_layer_trt(values_list_cuda)
        logger.info(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)
        )

        model_gpu_combined = torch.jit.load(
            os.path.join(save_dir, "scripted_model.pt"), map_location="cuda:0"
        )
        res = model_gpu_combined(data)
        logger.info("final res: %s", res)
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            with record_function("model_inference_combined_trt"):
                dict_res = model_gpu_combined(data)
        logger.info(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)
        )

    logger.info("trt convert success")