in tzrec/acc/export_utils.py [0:0]
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
ret = self.allowed_getattr_types()
assert not any(t is object for t in ret)
return ret
def _check_valid_op(op) -> None:
def _allowed_builtin_ops() -> List:
ret = self.allowed_builtin_ops()
assert all(inspect.isbuiltin(op) for op in ret)
return ret
def _allowed_op_types() -> Tuple[Type[Any], ...]:
ret = self.allowed_op_types()
assert not any(t is object for t in ret)
return ret
# TODO Remove this allowlist.
_allowed_torch_functions = (
torch.autograd.grad_mode.set_grad_enabled,
torch.sym_sum,
torch.sym_int,
torch.sym_float,
torch.sym_ite,
torch.sym_max,
torch.sym_min,
torch.sym_not,
torch.sym_sqrt,
# TODO (tmanlaibaatar)
# Predispatch export is able to contain autograd ops.
# These will be modeled as HOO later
torch._C._set_grad_enabled,
torch.amp.autocast_mode._enter_autocast,
torch.amp.autocast_mode._exit_autocast,
torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless,
)
if not isinstance(op, _allowed_op_types()):
if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
raise SpecViolationError(
f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" # NOQA
f"Valid builtin ops: {_allowed_builtin_ops()}"
f"Valid torch functions: {_allowed_torch_functions}"
)
if isinstance(op, OpOverload):
# All ops functional
# TODO (tmanlaibaatar) more proper way is needed here
if self.dialect != "TRAINING" and not is_functional(op):
raise SpecViolationError(f"operator '{op}' is not functional")
self.check_valid_op(op)
for mod in gm.modules():
if not isinstance(mod, torch.fx.GraphModule):
continue
mod.graph.lint()
for node in mod.graph.nodes:
# TODO(T140410192): should have fake tensor for all dialects
if node.op in {"call_module", "call_method"}:
raise SpecViolationError(
f"call_module is not valid: got a class '{node.target}' ",
)
elif node.op == "call_function":
_check_val(node)
_check_valid_op(node.target)
elif node.op == "get_attr":
if not isinstance(node.target, str):
raise SpecViolationError(
f"Expected get_attr target to be string, but got {type(node.target)}" # NOQA
)
attr = getattr_recursive(mod, node.target)
if isinstance(attr, torch.nn.Module):
def _is_type(name, ty):
return isinstance(getattr(attr, name, None), ty) # NOQA
if type(attr).__name__ == "LoweredBackendModule":
if (
_is_type("backend_id", str)
and _is_type("processed_bytes", bytes)
and _is_type("compile_specs", list)
and hasattr(attr, "original_module")
):
continue
else:
backend_id = getattr(attr, "backend_id", None)
processed_bytes = getattr(attr, "processed_bytes", None)
compile_specs = getattr(attr, "compile_specs", None)
raise SpecViolationError(
f"Invalid get_attr type {type(attr)}. \n"
f"LoweredBackendModule fields: "
f"backend_id(str) : {type(backend_id)}, "
f"processed_bytes(bytes) : {type(processed_bytes)}, "
f"compile_specs(list) : {type(compile_specs)}"
)
if not isinstance(attr, _allowed_getattr_types()):
raise SpecViolationError(
f"Invalid get_attr type {type(attr)}. \n"
f"Valid get_attr types: {_allowed_getattr_types()}"
)
elif node.op == "placeholder":
_check_val(node)
# TODO(zhxchen17)
# elif node.op == "output":
# _check_flattened_outputs()
self.check_additional(gm)