# tzrec/acc/trt_utils.py

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, Dict, List, Optional, Sequence

import torch

# The CPU image ships without torch_tensorrt, so tolerate an import failure.
try:
    import torch_tensorrt
except Exception:
    pass

from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function

from tzrec.acc.utils import is_debug_trt
from tzrec.models.model import ScriptWrapper
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger


def trt_convert(
    exp_program: torch.export.ExportedProgram,
    # pyre-ignore [2]
    inputs: Optional[Sequence[Sequence[Any]]],
) -> torch.fx.GraphModule:
    """Convert a model using TensorRT.

    Args:
        exp_program (torch.export.ExportedProgram): source exported program.
        inputs (List[Union[torch_tensorrt.Input, torch.Tensor]]): example inputs.

    Returns:
        torch.fx.GraphModule: compiled FX module that executes via TensorRT.
    """
    logger.info("trt convert start...")
    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
    enabled_precisions = {torch.float32}
    # Workspace size for TensorRT.
    workspace_size = 2 << 30
    # Minimum number of operators per TRT engine block
    # (a lower value allows more graph segmentation).
    min_block_size = 2
    # Use the scripted model; dict inputs are unsupported.
    compile_kwargs = {
        "enabled_precisions": enabled_precisions,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        "hardware_compatible": True,
        "assume_dynamic_shape_support": True,
        # "truncate_long_and_double": True,
        "allow_shape_tensors": True,
    }
    if is_debug_trt():
        with torch_tensorrt.logging.graphs():
            optimized_model = torch_tensorrt.dynamo.compile(
                exp_program,
                inputs,  # pyre-ignore [6]
                **compile_kwargs,
            )
    else:
        optimized_model = torch_tensorrt.dynamo.compile(
            exp_program,
            inputs,  # pyre-ignore [6]
            **compile_kwargs,
        )
    logger.info("trt convert end")
    return optimized_model
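# A minimal usage sketch of trt_convert (illustrative only: `MyDense` and the
# example shapes are hypothetical; the real call site is export_model_trt
# below, which builds the exported program with dynamic batch/seq dims first):
#
#   dense = MyDense().to("cuda:0").eval()
#   example_inputs = [torch.randn(2, 16, device="cuda:0")]
#   exp_program = torch.export.export(dense, (example_inputs,))
#   dense_trt = trt_convert(exp_program, example_inputs)
#   out = dense_trt(example_inputs)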
class ScriptWrapperList(ScriptWrapper):
    """Model inference wrapper for jit.script.

    ScriptWrapperList is used to trace the
    ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt) and returns a list of
    Tensors instead of a dict of Tensors.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__(module)

    # pyre-ignore [15]
    def forward(
        self,
        data: Dict[str, torch.Tensor],
        # pyre-ignore [9]
        device: torch.device = "cpu",
    ) -> List[torch.Tensor]:
        """Predict the model.

        Args:
            data (dict): a dict of input data for Batch.
            device (torch.device): inference device.

        Return:
            predictions (list): a list of predicted Tensors.
        """
        batch = self.get_batch(data, device)
        return self.model.predict(batch)


class ScriptWrapperTRT(nn.Module):
    """Model inference wrapper for jit.script."""

    def __init__(self, embedding_group: nn.Module, dense: nn.Module) -> None:
        super().__init__()
        self.embedding_group = embedding_group
        self.dense = dense

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        # pyre-ignore [9]
        device: torch.device = "cuda:0",
    ) -> Dict[str, torch.Tensor]:
        """Predict the model.

        Args:
            data (dict): a dict of input data for Batch.
            device (torch.device): inference device.

        Return:
            predictions (dict): a dict of predicted results.
        """
        grouped_features = self.embedding_group(data, device)
        y = self.dense(grouped_features)
        return y


def get_trt_max_batch_size() -> int:
    """Get trt max batch size.

    Returns:
        int: max_batch_size
    """
    return int(os.environ.get("TRT_MAX_BATCH_SIZE", 2048))


def get_trt_max_seq_len() -> int:
    """Get trt max seq len.

    Returns:
        int: max_seq_len
    """
    return int(os.environ.get("TRT_MAX_SEQ_LEN", 100))
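# The two helpers above bound the dynamic dims that export_model_trt passes to
# torch.export. A sketch of overriding the defaults before export (the values
# shown are examples, not recommendations):
#
#   os.environ["TRT_MAX_BATCH_SIZE"] = "1024"  # upper bound of the "batch" Dim
#   os.environ["TRT_MAX_SEQ_LEN"] = "200"      # upper bound of each "seq_len{i}" Dim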
def export_model_trt(
    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
) -> None:
    """Export trt model.

    Args:
        model (nn.Module): the model.
        data (Dict[str, torch.Tensor]): the test data.
        save_dir (str): model save dir.
    """
    # ScriptWrapperList is used to trace
    # ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt).
    emb_trace_gpu = ScriptWrapperList(model.model.embedding_group)
    emb_res = emb_trace_gpu(data, "cuda:0")
    emb_trace_gpu = symbolic_trace(emb_trace_gpu)
    emb_trace_gpu = torch.jit.script(emb_trace_gpu)

    # dynamic shapes
    max_batch_size = get_trt_max_batch_size()
    max_seq_len = get_trt_max_seq_len()
    batch = torch.export.Dim("batch", min=1, max=max_batch_size)
    dynamic_shapes_list = []
    values_list_cuda = []
    for i, value in enumerate(emb_res):
        v = value.detach().to("cuda:0")
        dict_dy = {0: batch}
        if v.dim() == 3:
            # Workaround for 0/1 specialization: torch.export specializes dims
            # of size 0/1, so pad dynamic dims of the example input to >= 2.
            if v.size(1) < 2:
                v = torch.zeros(
                    v.size(0), 2, v.size(2), device="cuda:0", dtype=v.dtype
                )
            dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=max_seq_len)
        if v.size(0) < 2:
            v = torch.zeros((2,) + v.size()[1:], device="cuda:0", dtype=v.dtype)
        values_list_cuda.append(v)
        dynamic_shapes_list.append(dict_dy)

    # convert dense
    dense = model.model.dense
    logger.info("dense res: %s", dense(values_list_cuda))
    dense_layer = symbolic_trace(dense)
    dynamic_shapes = {"args": dynamic_shapes_list}
    exp_program = torch.export.export(
        dense_layer, (values_list_cuda,), dynamic_shapes=dynamic_shapes
    )
    dense_layer_trt = trt_convert(exp_program, values_list_cuda)
    dict_res = dense_layer_trt(values_list_cuda)
    logger.info("dense trt res: %s", dict_res)

    # save combined_model
    combined_model = ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
    result = combined_model(data, "cuda:0")
    logger.info("combined model result: %s", result)
    # combined_model = symbolic_trace(combined_model)
    combined_model = torch.jit.trace(
        combined_model, example_inputs=(data,), strict=False
    )
    scripted_model = torch.jit.script(combined_model)
    # pyre-ignore [16]
    scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))

    if is_debug_trt():
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            with record_function("model_inference_dense"):
                dict_res = dense(values_list_cuda)
        logger.info(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)
        )

        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            with record_function("model_inference_dense_trt"):
                dict_res = dense_layer_trt(values_list_cuda)
        logger.info(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)
        )

        model_gpu_combined = torch.jit.load(
            os.path.join(save_dir, "scripted_model.pt"), map_location="cuda:0"
        )
        res = model_gpu_combined(data)
        logger.info("final res: %s", res)
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        ) as prof:
            with record_function("model_inference_combined_trt"):
                dict_res = model_gpu_combined(data)
        logger.info(
            prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)
        )

    logger.info("trt convert success")
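# A minimal end-to-end sketch (illustrative only; `model` and `data` are
# assumed to match what export_model_trt expects: a wrapper exposing
# model.model.embedding_group and model.model.dense, plus a dict of input
# tensors):
#
#   export_model_trt(model, data, save_dir="/tmp/trt_model")
#   scripted = torch.jit.load(
#       os.path.join("/tmp/trt_model", "scripted_model.pt"),
#       map_location="cuda:0",
#   )
#   preds = scripted(data)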