in grappler/auto_mixed_precision.cc [1284:1439]
// Runs the full auto-mixed-precision rewrite over graph_: decides which nodes
// can safely run in f16 (fp16 on GPU, bf16 with MKL) and inserts Cast ops at
// the f16/fp32 boundaries. Returns a non-OK Status on any configuration or
// graph-analysis error.
Status AutoMixedPrecisionImpl::Optimize() {
  string optimization_level;
  TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
      "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
  optimization_level = absl::AsciiStrToUpper(optimization_level);
  force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
  if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
    // Many ops do not support bfloat16 on the CPU so we disallow forcing all
    // ops to bfloat16.
    return errors::InvalidArgument(
        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
        "UNSAFE_FORCE_ALL when MKL is used");
  }

  // Load the op categorization lists for the current mode. Allowlist ops are
  // performance-critical and numerically safe in f16; denylist ops are
  // numerically dangerous; infer/clear ops are conditionally safe.
  std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
      get_mixed_precision_lists();
  f16_allowlist_ = mp_lists->AllowList();
  f16_denylist_ = mp_lists->DenyList();
  f16_inferlist_ = mp_lists->InferList();
  f16_clearlist_ = mp_lists->ClearList();
  TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
                                   f16_inferlist_, f16_clearlist_));

  size_t timestamp = Env::Default()->NowMicros() / 1000;
  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));

  VLOG(2) << "Identifying nodes that should be processed";
  for (const NodeDef& node : graph_->node()) {
    // Initialize defensively: the switch below has no default case, so an
    // unexpected mode_ value must not leave this uninitialized (UB).
    bool should_process = false;
    switch (mode_) {
      case AutoMixedPrecisionMode::CUDA:
        should_process =
            !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
            (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
        break;
      // MKL and NEURON both process CPU-placed, non-preserved nodes.
      case AutoMixedPrecisionMode::MKL:
      case AutoMixedPrecisionMode::NEURON:
        should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
        break;
    }
    if (should_process) {
      should_process_nodes_.insert(&node);
    } else {
      LogSkippedNode(node);
    }
  }

  VLOG(2) << "Converting FusedBatchNorm* ops to V2";
  ConvertBatchNormOpsToV2();

  VLOG(2) << "Building node type map for graph";
  TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));

  VLOG(2) << "Constructing graph type attribute topology view";
  TF_RETURN_IF_ERROR(
      graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));

  absl::flat_hash_set<int> deny_set;

  // TensorList ops pass tensors implicitly (via handles), so explicit edges
  // are added to the type-topology view to propagate paint decisions through
  // them; unsafe clusters are denied up front.
  std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
  FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
                                                   &deny_set);
  std::vector<NodeTypeIdEdge> ephemeral_edges;
  for (const auto& cluster : tensor_list_clusters) {
    VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
    for (const NodeDef* node : cluster) {
      VLOG(2) << " Cluster member: " << node->op() << " node " << node->name();
    }
    FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
  }
  TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));

  // The goal here is to change performance-critical ops to fp16 or bf16, and to
  // do so with the minimal number of casts, subject to the constraint that the
  // model's convergence is not affected. This is achieved by first identifying
  // which nodes should be changed to f16 and then inserting casts at the
  // boundaries between f16/non-f16 nodes.
  // The algorithm for deciding which nodes to change to f16 is as follows:
  // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
  //    This is done under the assumption that allowlist ops are always
  //    numerically-safe in f16 and that they are the most important ops for
  //    improving performance.
  // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
  //    "denylist" ops) or they are on a forward path from a denylist node to
  //    a deny/infer node (including the node at the end of the path) through
  //    non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
  //    This is done to prevent numerically-dangerous ops and their downstream
  //    effects from being changed to f16, which would risk breaking the
  //    numerical accuracy of the model.
  // 3) For all remaining nodes that are not considered dangerous (inferlist
  //    and clearlist ops), find those that are between (i.e., both upstream
  //    and downstream of) allow nodes, and add them to the allow_set.
  //    This is done to avoid unnecessary casts between allowlist ops.
  // 4) For all remaining clearlist nodes, add them to the allow_set if they are
  //    connected to a node in the allow_set via other clearlist nodes.
  //    This is done to increase the number of ops in the allow_set without
  //    affecting numerical stability.

  absl::flat_hash_set<int> allow_set;
  VLOG(2) << "Beginning pass 1 to add allowlist ops";
  AddAllowlistOps(&allow_set);
  VLOG(2) << "Finished pass 1";

  if (allow_set.empty()) {
    LOG(INFO) << "No allowlist ops found, nothing to do";
    return Status::OK();
  }

  VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
             "through clear/inferlist ops";
  PropagateDenyFwdThroughClearAndInfer(&deny_set);
  VLOG(2) << "Finished pass 2";

  VLOG(2) << "Forcing color match between data structure ops";
  for (const auto& cluster : tensor_list_clusters) {
    ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
  }

  VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
             "are between allow ops";
  AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
  VLOG(2) << "Finished pass 3";

  VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
             "clearlist ops";
  PropagateAllowThroughClear(deny_set, &allow_set);
  VLOG(2) << "Finished pass 4";

  // NOTE: the adjacent string literals previously concatenated to
  // "to F16from allow set"; a separating space has been added.
  VLOG(2) << "Beginning pass 5 to remove some nodes which could not be changed "
             "to F16 "
             "from allow set";
  RemoveAllowsetWithFp32(&allow_set);
  VLOG(2) << "Finished pass 5";

  // Re-run the color match: passes 3-5 may have repainted TensorList members.
  VLOG(2) << "Forcing color match between data structure ops";
  for (const auto& cluster : tensor_list_clusters) {
    ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
  }

  VLOG(2) << "Forcing color match on loop edges";
  TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));

  VLOG(2) << "Finding existing casts that can be made allow";
  MakeCastsAllowIfAllOutputsAllow(&allow_set);

  VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
             "ops at paint boundaries";
  TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
  VLOG(2) << "Finished final pass";

  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));

  return Status::OK();
}