Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts()

in grappler/auto_mixed_precision.cc [1925:1999]
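Rewrites the float32 type attributes of allow-set nodes to the target half-precision type (float16 or bfloat16) and inserts Cast nodes on every edge that crosses between a converted node and an unconverted one, then logs a summary of the conversion.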


Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
    const absl::flat_hash_set<int>& allow_set) {
  int num_nodes_changed = 0;
  int num_nonvar_casts_to_f16 = 0;
  int num_nodes_preop = graph_->node_size();
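  // Only iterate over nodes that existed before this pass; Cast nodes added
  // below are appended to the graph and are not revisited.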
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node->name(), type_attr);
      if (!maybe_node_type_idx.has_value()) {
        return errors::Internal("Type attribute ", type_attr.DebugString(),
                                " of ", node->op(), " node ", node->name(),
                                " not found in graph view");
      }
      int node_type_idx = maybe_node_type_idx.value();
      if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
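      // Nodes in the allow set have this float32 type attribute rewritten to
      // the target half-precision type; all other nodes keep float32.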
      bool src_is_allow = allow_set.count(node_type_idx);
      if (src_is_allow) {
        VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
                << node->op() << " node " << node->name() << " to "
                << DataTypeString(target_dtype_);
        if (!SetDataType(node, type_attr, target_dtype_)) {
          return errors::Internal("Failed to set type attribute");
        }
        ++num_nodes_changed;
      }
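      // Insert casts on any output edge that crosses between an allow-set
      // node and a non-allow-set node, reusing a single Cast per output port.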
      for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
        MutableGraphView::OutputPort src(node, output_port);
        NodeDef* added_cast_node = nullptr;
        // Note: This is copied so that edges can be modified inside the loop.
        auto fanout = graph_view_.GetFanout(src);
        for (const MutableGraphView::InputPort& dst : fanout) {
          TypeAttrId dst_type_attr =
              node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
          const absl::optional<int> maybe_dst_type_idx =
              graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
          if (!maybe_dst_type_idx.has_value()) {
            return errors::Internal("Type attribute ",
                                    dst_type_attr.DebugString(), " of ",
                                    dst.node->op(), " node ", dst.node->name(),
                                    " not found in graph view");
          }
          int dst_type_idx = maybe_dst_type_idx.value();
          bool dst_is_allow = allow_set.count(dst_type_idx);
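          // An allow/deny mismatch means this edge crosses the f16/f32
          // boundary; the cast direction depends on whether the destination
          // is in the allow set.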
          if (src_is_allow != dst_is_allow) {
            if (!added_cast_node) {
              bool to_f16 = dst_is_allow;
              VLOG(1) << "Inserting cast to "
                      << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
                      << " at " << src.node->op() << " " << src.node->name()
                      << ":" << src.port_id;
              added_cast_node = graph_view_.AddNode(
                  BuildCastNode(src, to_f16, src.node->device()));
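              // Casts whose source is a Const or Variable (or implicitly
              // reads a non-resource variable) are excluded from the
              // reported cast count.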
              if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
                  !NodeImplicitlyReadsNonResourceVariable(*node)) {
                ++num_nonvar_casts_to_f16;
              }
            }
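            // Reroute the destination's input through the newly added Cast.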
            TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
                dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
          }
        }
      }
    }
  }
  // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
  // since many Python users will see this message.
  const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
  LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
            << " nodes to " << type_str << " precision using "
            << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
            << " (excluding Const and Variable casts)";
  return Status::OK();
}