tzrec/acc/export_utils.py

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from typing import Any, Dict, List, Tuple, Type

import torch
import torch._prims_common as prims_utils
import torch.nn.functional as F
from torch import nn
from torch._decomp import decomposition_table, register_decomposition
from torch._export.verifier import (
    OpOverload,
    SpecViolationError,
    Verifier,
    _check_val,
    getattr_recursive,
    is_functional,
)
from torch._prims_common.wrappers import out_wrapper
from torch.export import Dim

from tzrec.acc.utils import is_trt
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger

# add a new aten._softmax decomposition that is supported by dynamo
aten = torch._ops.ops.aten
if aten._softmax.default in decomposition_table:
    del decomposition_table[aten._softmax.default]
    del decomposition_table[aten._softmax.out]


# pyre-ignore [56]
@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
    # a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = prims_utils.elementwise_dtypes(
        x, type_promotion_kind=prims_utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    x_max = torch.max(x, dim, keepdim=True).values
    unnormalized = torch.exp(x - x_max)
    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
    if not half_to_float:
        result = result.to(result_dtype)
    return result


# dynamo now generates sym_int5 = sym_sum(sym_int1, sym_int2, sym_int3) ops
# instead of sym_int4 = sym_int1 + sym_int2; sym_int5 = sym_int4 + sym_int3.
# patch Verifier._check_graph_module temporarily to add sym_sum to the allowed ops.
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
    def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
        ret = self.allowed_getattr_types()
        assert not any(t is object for t in ret)
        return ret

    def _check_valid_op(op) -> None:
        def _allowed_builtin_ops() -> List:
            ret = self.allowed_builtin_ops()
            assert all(inspect.isbuiltin(op) for op in ret)
            return ret

        def _allowed_op_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_op_types()
            assert not any(t is object for t in ret)
            return ret

        # TODO Remove this allowlist.
        _allowed_torch_functions = (
            torch.autograd.grad_mode.set_grad_enabled,
            torch.sym_sum,
            torch.sym_int,
            torch.sym_float,
            torch.sym_ite,
            torch.sym_max,
            torch.sym_min,
            torch.sym_not,
            torch.sym_sqrt,
            # TODO (tmanlaibaatar)
            # Predispatch export is able to contain autograd ops.
            # These will be modeled as HOO later
            torch._C._set_grad_enabled,
            torch.amp.autocast_mode._enter_autocast,
            torch.amp.autocast_mode._exit_autocast,
            torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless,
        )

        if not isinstance(op, _allowed_op_types()):
            if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                raise SpecViolationError(
                    f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"  # NOQA
                    f"Valid builtin ops: {_allowed_builtin_ops()}"
                    f"Valid torch functions: {_allowed_torch_functions}"
                )

        if isinstance(op, OpOverload):
            # All ops functional
            # TODO (tmanlaibaatar) more proper way is needed here
            if self.dialect != "TRAINING" and not is_functional(op):
                raise SpecViolationError(f"operator '{op}' is not functional")
        self.check_valid_op(op)

    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue

        mod.graph.lint()
        for node in mod.graph.nodes:
            # TODO(T140410192): should have fake tensor for all dialects
            if node.op in {"call_module", "call_method"}:
                raise SpecViolationError(
                    f"call_module is not valid: got a class '{node.target}' ",
                )
            elif node.op == "call_function":
                _check_val(node)
                _check_valid_op(node.target)
            elif node.op == "get_attr":
                if not isinstance(node.target, str):
                    raise SpecViolationError(
                        f"Expected get_attr target to be string, but got {type(node.target)}"  # NOQA
                    )

                attr = getattr_recursive(mod, node.target)
                if isinstance(attr, torch.nn.Module):

                    def _is_type(name, ty):
                        return isinstance(getattr(attr, name, None), ty)  # NOQA

                    if type(attr).__name__ == "LoweredBackendModule":
                        if (
                            _is_type("backend_id", str)
                            and _is_type("processed_bytes", bytes)
                            and _is_type("compile_specs", list)
                            and hasattr(attr, "original_module")
                        ):
                            continue
                        else:
                            backend_id = getattr(attr, "backend_id", None)
                            processed_bytes = getattr(attr, "processed_bytes", None)
                            compile_specs = getattr(attr, "compile_specs", None)
                            raise SpecViolationError(
                                f"Invalid get_attr type {type(attr)}. \n"
                                f"LoweredBackendModule fields: "
                                f"backend_id(str) : {type(backend_id)}, "
                                f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                f"compile_specs(list) : {type(compile_specs)}"
                            )

                if not isinstance(attr, _allowed_getattr_types()):
                    raise SpecViolationError(
                        f"Invalid get_attr type {type(attr)}. \n"
                        f"Valid get_attr types: {_allowed_getattr_types()}"
                    )
            elif node.op == "placeholder":
                _check_val(node)
            # TODO(zhxchen17)
            # elif node.op == "output":
            #     _check_flattened_outputs()

    self.check_additional(gm)


Verifier._check_graph_module = _check_graph_module


def export_pm(
    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
) -> Tuple[torch.export.ExportedProgram, Dict[str, torch.Tensor]]:
    """Export a PyTorch model and its parameters.

    Args:
        model (nn.Module): The PyTorch model to export.
        data (Dict[str, torch.Tensor]): Dict containing the model's input tensors.
        save_dir (str): The directory where the model should be saved.

    Returns:
        Tuple[torch.export.ExportedProgram, Dict[str, torch.Tensor]]: The exported
            program and its input data.
    """
    gm = symbolic_trace(model)
    with open(os.path.join(save_dir, "gm.code"), "w") as f:
        f.write(gm.code)
    gm = gm.cuda()

    batch = Dim("batch")
    dynamic_shapes = {}
    for key in data:
        # .lengths
        if key.endswith(".lengths"):
            # user feats
            if key.split(".")[0] in model._data_parser.user_feats:
                assert data[key].shape[0] == 1
                logger.info(
                    "uniq user length fea %s length=%s" % (key, data[key].shape)
                )
                dynamic_shapes[key] = {}
            else:
                logger.info("batch length fea=%s shape=%s" % (key, data[key].shape))
                dynamic_shapes[key] = {0: batch}
        elif key == "batch_size":
            dynamic_shapes[key] = {}
        # dense values
        elif key.split(".")[0] in model._data_parser.dense_keys_list:
            # user feats
            if key.split(".")[0] in model._data_parser.user_feats:
                assert data[key].shape[0] == 1
                logger.info("uniq user dense_fea=%s shape=%s" % (key, data[key].shape))
                dynamic_shapes[key] = {}
            else:
                logger.info("batch dense_fea=%s shape=%s" % (key, data[key].shape))
                dynamic_shapes[key] = {0: batch}
        # seq_dense or sparse
        else:
            if data[key].shape[0] < 2:
                data[key] = F.pad(
                    data[key],
                    [0, 0] * (len(data[key].shape) - 1) + [0, 2],
                    mode="constant",
                )
                data[key.split(".")[0] + ".lengths"][0] = data[key].shape[0]
            logger.info("sparse or seq dense fea=%s shape=%s" % (key, data[key].shape))
            tmp_val_dim = Dim(key.replace(".", "__") + "__batch", min=0)
            dynamic_shapes[key] = {0: tmp_val_dim}

        # trt needs contiguous format
        if is_trt():
            data[key] = data[key].contiguous()

    logger.info("dynamic shapes=%s" % dynamic_shapes)
    exported_pg = torch.export.export(
        gm, args=(data,), dynamic_shapes=(dynamic_shapes,)
    )
    export_path = os.path.join(save_dir, "exported_pg.py")
    with open(export_path, "w") as fout:
        fout.write(str(exported_pg))
    exported_pg.module()(data)
    return (exported_pg, data)