Status AutoMixedPrecisionImpl::Optimize()

in grappler/auto_mixed_precision.cc [1284:1439]


// Runs the auto-mixed-precision rewrite over graph_.
//
// Steps, in order:
//  * Reads TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL. The value
//    "UNSAFE_FORCE_ALL" (compared case-insensitively) sets force_all_fp16_.
//    That setting is rejected in MKL mode.
//  * Loads the allow/deny/infer/clear op lists from
//    get_mixed_precision_lists() and validates them.
//  * Records the nodes eligible for rewriting in should_process_nodes_.
//    Ineligible nodes are passed to LogSkippedNode().
//  * Converts FusedBatchNorm* ops to V2 and builds the node type map and the
//    graph type-attribute view. The view includes ephemeral edges for the
//    TensorList clusters that were found.
//  * Runs the painting passes described in the body. It then changes type
//    attributes and inserts Cast ops at allow/non-allow boundaries.
//
// Returns OK early, skipping the painting and cast-insertion passes, when no
// allowlist op is found. Returns a non-OK status when the environment
// variable cannot be read, the op lists fail validation, or any
// graph-building or rewriting step fails.
Status AutoMixedPrecisionImpl::Optimize() {
  string optimization_level;
  TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
      "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
  optimization_level = absl::AsciiStrToUpper(optimization_level);
  force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
  if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
    // Many ops do not support bfloat16 on the CPU, so we disallow forcing
    // everything to bfloat16.
    return errors::InvalidArgument(
        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
        "UNSAFE_FORCE_ALL when MKL is used");
  }

  std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
      get_mixed_precision_lists();
  f16_allowlist_ = mp_lists->AllowList();
  f16_denylist_ = mp_lists->DenyList();
  f16_inferlist_ = mp_lists->InferList();
  f16_clearlist_ = mp_lists->ClearList();
  TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
                                   f16_inferlist_, f16_clearlist_));
  // Milliseconds since the epoch, passed to both PrintDebugLogs() calls.
  size_t timestamp = Env::Default()->NowMicros() / 1000;
  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));

  VLOG(2) << "Identifying nodes that should be processed";
  for (const NodeDef& node : graph_->node()) {
    // Initialized so that a mode_ value the switch does not handle skips the
    // node instead of reading an indeterminate value. There is deliberately
    // no `default:` label, so -Wswitch still flags newly added enumerators.
    bool should_process = false;
    switch (mode_) {
      case AutoMixedPrecisionMode::CUDA:
        should_process =
            !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
            (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
        break;
      case AutoMixedPrecisionMode::MKL:
      case AutoMixedPrecisionMode::NEURON:
        // Both CPU-targeted modes share the same eligibility rule.
        should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
        break;
    }
    if (should_process) {
      should_process_nodes_.insert(&node);
    } else {
      LogSkippedNode(node);
    }
  }

  VLOG(2) << "Converting FusedBatchNorm* ops to V2";
  ConvertBatchNormOpsToV2();

  VLOG(2) << "Building node type map for graph";
  TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));

  VLOG(2) << "Constructing graph type attribute topology view";
  TF_RETURN_IF_ERROR(
      graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));

  // deny_set is populated before pass 1. The TensorList cluster search below
  // receives it and may add nodes to it (judging by the helper's name, the
  // unsafe TensorList ops -- confirm in the helper).
  absl::flat_hash_set<int> deny_set;

  std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
  FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
                                                   &deny_set);
  std::vector<NodeTypeIdEdge> ephemeral_edges;
  for (const auto& cluster : tensor_list_clusters) {
    VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
    for (const NodeDef* node : cluster) {
      VLOG(2) << "  Cluster member: " << node->op() << " node " << node->name();
    }
    FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
  }
  TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));

  // The goal here is to change performance-critical ops to fp16 or bf16, and to
  // do so with the minimal number of casts, subject to the constraint that the
  // model's convergence is not affected. This is achieved by first identifying
  // which nodes should be changed to f16 and then inserting casts at the
  // boundaries between f16/non-f16 nodes.

  // The algorithm for deciding which nodes to change to f16 is as follows:
  // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
  //    This is done under the assumption that allowlist ops are always
  //    numerically-safe in f16 and that they are the most important ops for
  //    improving performance.
  // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
  //    "denylist" ops) or they are on a forward path from a denylist node to
  //    a deny/infer node (including the node at the end of the path) through
  //    non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
  //    This is done to prevent numerically-dangerous ops and their downstream
  //    effects from being changed to f16, which would risk breaking the
  //    numerical accuracy of the model.
  // 3) For all remaining nodes that are not considered dangerous (inferlist
  //    and clearlist ops), find those that are between (i.e., both upstream
  //    and downstream of) allow nodes, and add them to the allow_set.
  //    This is done to avoid unnecessary casts between allowlist ops.
  // 4) For all remaining clearlist nodes, add them to the allow_set if they are
  //    connected to a node in the allow_set via other clearlist nodes.
  //    This is done to increase the number of ops in the allow_set without
  //    affecting numerical stability.
  // 5) Remove from the allow_set the nodes that could not be changed to f16.

  absl::flat_hash_set<int> allow_set;
  VLOG(2) << "Beginning pass 1 to add allowlist ops";
  AddAllowlistOps(&allow_set);
  VLOG(2) << "Finished pass 1";

  if (allow_set.empty()) {
    // Nothing to paint. Note that the post-op debug dump below is not emitted
    // on this path.
    LOG(INFO) << "No allowlist ops found, nothing to do";
    return Status::OK();
  }

  VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
             "through clear/inferlist ops";
  PropagateDenyFwdThroughClearAndInfer(&deny_set);
  VLOG(2) << "Finished pass 2";

  VLOG(2) << "Forcing color match between data structure ops";
  for (const auto& cluster : tensor_list_clusters) {
    ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
  }

  VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
             "are between allow ops";
  AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
  VLOG(2) << "Finished pass 3";

  VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
             "clearlist ops";
  PropagateAllowThroughClear(deny_set, &allow_set);
  VLOG(2) << "Finished pass 4";

  VLOG(2) << "Beginning pass 5 to remove some nodes which could not be changed "
             "to F16 from allow set";
  RemoveAllowsetWithFp32(&allow_set);
  VLOG(2) << "Finished pass 5";

  // Repeated because passes 3-5 may have changed allow_set membership since
  // the first color-match loop ran.
  VLOG(2) << "Forcing color match between data structure ops";
  for (const auto& cluster : tensor_list_clusters) {
    ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
  }

  VLOG(2) << "Forcing color match on loop edges";
  TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));

  VLOG(2) << "Finding existing casts that can be made allow";
  MakeCastsAllowIfAllOutputsAllow(&allow_set);

  VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
             "ops at paint boundaries";
  TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
  VLOG(2) << "Finished final pass";

  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));

  return Status::OK();
}