in grappler/auto_mixed_precision.cc [1925:1999]
// Second phase of auto mixed precision: for every float32 type attribute that
// the analysis placed in `allow_set`, rewrite it to the reduced-precision
// target dtype (float16 or bfloat16), and insert Cast nodes on every edge
// whose endpoints ended up on opposite sides of the allow/deny partition so
// the graph remains type-consistent.
//
// `allow_set` holds indices into `graph_type_view_` identifying the
// (node, type-attribute) pairs approved for reduced precision.
// Returns an Internal error if a type attribute cannot be resolved in the
// graph view or if setting a node's dtype attribute fails.
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
    const absl::flat_hash_set<int>& allow_set) {
  int num_nodes_changed = 0;
  // Counts casts *to* f16 excluding those fed by Consts/Variables; reported
  // in the summary log below so users can gauge conversion overhead.
  int num_nonvar_casts_to_f16 = 0;
  // Snapshot the node count: AddNode() below appends Cast nodes to the graph,
  // and we must not iterate over the nodes we are inserting.
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    // A node may carry several type attributes (e.g. "T", "Tparams"); each is
    // considered independently.
    for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node->name(), type_attr);
      if (!maybe_node_type_idx.has_value()) {
        return errors::Internal("Type attribute ", type_attr.DebugString(),
                                " of ", node->op(), " node ", node->name(),
                                " not found in graph view");
      }
      int node_type_idx = maybe_node_type_idx.value();
      // Only float32 attributes are candidates; everything else is untouched.
      if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
      bool src_is_allow = allow_set.count(node_type_idx);
      if (src_is_allow) {
        VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
                << node->op() << " node " << node->name() << " to "
                << DataTypeString(target_dtype_);
        if (!SetDataType(node, type_attr, target_dtype_)) {
          return errors::Internal("Failed to set type attribute");
        }
        ++num_nodes_changed;
      }
      // Walk every output edge governed by this type attribute and insert a
      // Cast wherever the producer and consumer disagree on precision.
      for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
        MutableGraphView::OutputPort src(node, output_port);
        // One shared Cast node per output port, created lazily on the first
        // mismatched consumer and reused for the rest.
        NodeDef* added_cast_node = nullptr;
        // Note: This is copied so that edges can be modified inside the loop.
        auto fanout = graph_view_.GetFanout(src);
        for (const MutableGraphView::InputPort& dst : fanout) {
          TypeAttrId dst_type_attr =
              node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
          const absl::optional<int> maybe_dst_type_idx =
              graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
          if (!maybe_dst_type_idx.has_value()) {
            return errors::Internal("Type attribute ",
                                    dst_type_attr.DebugString(), " of ",
                                    dst.node->op(), " node ", dst.node->name(),
                                    " not found in graph view");
          }
          int dst_type_idx = maybe_dst_type_idx.value();
          bool dst_is_allow = allow_set.count(dst_type_idx);
          // A cast is needed exactly when the edge crosses the allow/deny
          // boundary, in either direction.
          if (src_is_allow != dst_is_allow) {
            if (!added_cast_node) {
              // Cast direction follows the *consumer*: if the consumer is in
              // the allow set, the producer stayed float32 and we cast down.
              bool to_f16 = dst_is_allow;
              VLOG(1) << "Inserting cast to "
                      << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
                      << " at " << src.node->op() << " " << src.node->name()
                      << ":" << src.port_id;
              added_cast_node = graph_view_.AddNode(
                  BuildCastNode(src, to_f16, node));
              // Const/Variable-fed casts are excluded from the user-facing
              // count since they are typically folded away or one-time costs.
              if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
                  !NodeImplicitlyReadsNonResourceVariable(*node)) {
                ++num_nonvar_casts_to_f16;
              }
            }
            // Rewire the consumer to read output 0 of the Cast instead of the
            // original producer port.
            TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
                dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
          }
        }
      }
    }
  }
  // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
  // since many Python users will see this message.
  const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
  LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
            << " nodes to " << type_str << " precision using "
            << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
            << " (excluding Const and Variable casts)";
  return Status::OK();
}