# tzrec/acc/export_utils.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from typing import Any, Dict, List, Tuple, Type

import torch
import torch._prims_common as prims_utils
import torch.nn.functional as F
from torch import nn
from torch._decomp import decomposition_table, register_decomposition
from torch._export.verifier import (
    OpOverload,
    SpecViolationError,
    Verifier,
    _check_val,
    getattr_recursive,
    is_functional,
)
from torch._prims_common.wrappers import out_wrapper
from torch.export import Dim

from tzrec.acc.utils import is_trt
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger

# add a new aten._softmax decomposition that is supported by dynamo
aten = torch._ops.ops.aten
if aten._softmax.default in decomposition_table:
    del decomposition_table[aten._softmax.default]
    del decomposition_table[aten._softmax.out]
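# Note: both the .default and .out entries point at the stock decomposition;
# presumably they are removed first so the dynamo-friendly variant registered
# below can take their place without clashing with the existing registration.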


# pyre-ignore [56]
@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
    # a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = prims_utils.elementwise_dtypes(
        x, type_promotion_kind=prims_utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    x_max = torch.max(x, dim, keepdim=True).values
    unnormalized = torch.exp(x - x_max)
    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
    if not half_to_float:
        result = result.to(result_dtype)
    return result
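

# Illustrative check (not executed here): the decomposition above is the usual
# numerically stable softmax (subtract the per-dim max, exponentiate, normalize
# by the sum), so it should match eager softmax, e.g.
#   x = torch.randn(4, 8)
#   torch.testing.assert_close(
#       _softmax(x, dim=-1, half_to_float=False), F.softmax(x, dim=-1)
#   )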


# now dynamo generates a sym_int5 = sym_sum(sym_int1, sym_int2, sym_int3) op
# instead of sym_int4 = sym_int1 + sym_int2; sym_int5 = sym_int4 + sym_int3.
# patch Verifier._check_graph_module temporarily to add sym_sum to the allowed ops
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
    def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
        ret = self.allowed_getattr_types()
        assert not any(t is object for t in ret)
        return ret

    def _check_valid_op(op) -> None:
        def _allowed_builtin_ops() -> List:
            ret = self.allowed_builtin_ops()
            assert all(inspect.isbuiltin(op) for op in ret)
            return ret

        def _allowed_op_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_op_types()
            assert not any(t is object for t in ret)
            return ret

        # TODO Remove this allowlist.
        _allowed_torch_functions = (
            torch.autograd.grad_mode.set_grad_enabled,
            torch.sym_sum,
            torch.sym_int,
            torch.sym_float,
            torch.sym_ite,
            torch.sym_max,
            torch.sym_min,
            torch.sym_not,
            torch.sym_sqrt,
            # TODO (tmanlaibaatar)
            # Predispatch export is able to contain autograd ops.
            # These will be modeled as HOO later
            torch._C._set_grad_enabled,
            torch.amp.autocast_mode._enter_autocast,
            torch.amp.autocast_mode._exit_autocast,
            torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless,
        )

        if not isinstance(op, _allowed_op_types()):
            if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                raise SpecViolationError(
                    f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"  # NOQA
                    f"Valid builtin ops: {_allowed_builtin_ops()}"
                    f"Valid torch functions: {_allowed_torch_functions}"
                )

        if isinstance(op, OpOverload):
            # All ops functional
            # TODO (tmanlaibaatar) more proper way is needed here
            if self.dialect != "TRAINING" and not is_functional(op):
                raise SpecViolationError(f"operator '{op}' is not functional")
        self.check_valid_op(op)

    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue

        mod.graph.lint()
        for node in mod.graph.nodes:
            # TODO(T140410192): should have fake tensor for all dialects
            if node.op in {"call_module", "call_method"}:
                raise SpecViolationError(
                    f"call_module is not valid: got a class '{node.target}' ",
                )
            elif node.op == "call_function":
                _check_val(node)
                _check_valid_op(node.target)
            elif node.op == "get_attr":
                if not isinstance(node.target, str):
                    raise SpecViolationError(
                        f"Expected get_attr target to be string, but got {type(node.target)}"  # NOQA
                    )

                attr = getattr_recursive(mod, node.target)
                if isinstance(attr, torch.nn.Module):

                    def _is_type(name, ty):
                        return isinstance(getattr(attr, name, None), ty)  # NOQA

                    if type(attr).__name__ == "LoweredBackendModule":
                        if (
                            _is_type("backend_id", str)
                            and _is_type("processed_bytes", bytes)
                            and _is_type("compile_specs", list)
                            and hasattr(attr, "original_module")
                        ):
                            continue
                        else:
                            backend_id = getattr(attr, "backend_id", None)
                            processed_bytes = getattr(attr, "processed_bytes", None)
                            compile_specs = getattr(attr, "compile_specs", None)
                            raise SpecViolationError(
                                f"Invalid get_attr type {type(attr)}. \n"
                                f"LoweredBackendModule fields: "
                                f"backend_id(str) : {type(backend_id)}, "
                                f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                f"compile_specs(list) : {type(compile_specs)}"
                            )

                if not isinstance(attr, _allowed_getattr_types()):
                    raise SpecViolationError(
                        f"Invalid get_attr type {type(attr)}. \n"
                        f"Valid get_attr types: {_allowed_getattr_types()}"
                    )
            elif node.op == "placeholder":
                _check_val(node)
            # TODO(zhxchen17)
            # elif node.op == "output":
            #     _check_flattened_outputs()

    self.check_additional(gm)


Verifier._check_graph_module = _check_graph_module
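# Illustrative note: with torch.sym_sum on the allowlist, an exported graph
# that computes a total length as, e.g., ``torch.sym_sum([s0, s1, s2])`` (rather
# than chained ``s0 + s1 + s2`` additions) now passes verification instead of
# raising SpecViolationError.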


def export_pm(
    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
) -> Tuple[torch.export.ExportedProgram, Dict[str, torch.Tensor]]:
    """Export a PyTorch model and its parameters.

    Args:
        model (nn.Module): The PyTorch model to export.
        data (Dict[str, torch.Tensor]): A dict containing the model's input tensors.
        save_dir (str): The directory where the model should be saved.

    Returns:
        Tuple[torch.export.ExportedProgram, Dict[str, torch.Tensor]]:
            The exported program and its input data.
    """
    gm = symbolic_trace(model)
    with open(os.path.join(save_dir, "gm.code"), "w") as f:
        f.write(gm.code)
    gm = gm.cuda()
    batch = Dim("batch")
    dynamic_shapes = {}
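    # Summary of the loop below: user features are deduplicated to a single row,
    # so their shapes stay static; per-request ``.lengths`` and dense-value
    # tensors share the ``batch`` dim on axis 0; sparse and sequence-dense value
    # tensors get their own per-key value dims, since their row counts depend on
    # the variable-length data rather than on the batch size alone.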
    for key in data:
        # .lengths
        if key.endswith(".lengths"):
            # user feats
            if key.split(".")[0] in model._data_parser.user_feats:
                assert data[key].shape[0] == 1
                logger.info(
                    "uniq user length fea %s length=%s" % (key, data[key].shape)
                )
                dynamic_shapes[key] = {}
            else:
                logger.info("batch length fea=%s shape=%s" % (key, data[key].shape))
                dynamic_shapes[key] = {0: batch}
        elif key == "batch_size":
            dynamic_shapes[key] = {}
        # dense values
        elif key.split(".")[0] in model._data_parser.dense_keys_list:
            # user feats
            if key.split(".")[0] in model._data_parser.user_feats:
                assert data[key].shape[0] == 1
                logger.info("uniq user dense_fea=%s shape=%s" % (key, data[key].shape))
                dynamic_shapes[key] = {}
            else:
                logger.info("batch dense_fea=%s shape=%s" % (key, data[key].shape))
                dynamic_shapes[key] = {0: batch}
        # seq_dense or sparse
        else:
            if data[key].shape[0] < 2:
                data[key] = F.pad(
                    data[key],
                    [0, 0] * (len(data[key].shape) - 1) + [0, 2],
                    mode="constant",
                )
                data[key.split(".")[0] + ".lengths"][0] = data[key].shape[0]
            logger.info("sparse or seq dense fea=%s shape=%s" % (key, data[key].shape))
            tmp_val_dim = Dim(key.replace(".", "__") + "__batch", min=0)
            dynamic_shapes[key] = {0: tmp_val_dim}
            # trt needs contiguous format
            if is_trt():
                data[key] = data[key].contiguous()
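    # For illustration (hypothetical feature names): after the loop,
    # ``dynamic_shapes`` may look like
    #   {
    #       "user_age.values": {},           # uniq user feature, static shape
    #       "item_id.lengths": {0: batch},   # per-request lengths
    #       "item_id.values": {0: Dim("item_id__values__batch", min=0)},
    #   }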
logger.info("dynamic shapes=%s" % dynamic_shapes)
exported_pg = torch.export.export(
gm, args=(data,), dynamic_shapes=(dynamic_shapes,)
)
export_path = os.path.join(save_dir, "exported_pg.py")
with open(export_path, "w") as fout:
fout.write(str(exported_pg))
exported_pg.module()(data)
return (exported_pg, data)
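

# Minimal usage sketch (illustrative; ``model``, ``data`` and the save dir are
# placeholders, and a CUDA device is assumed since the graph module is moved to
# GPU above):
#   exported_pg, padded_data = export_pm(model, data, save_dir="/tmp/export")
#   torch.export.save(exported_pg, "/tmp/export/model.pt2")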