Status MarkForClustering()

in ngraph_bridge/mark_for_clustering.cc [568:722]


// Decides which op nodes of `graph` nGraph should take over and tags them.
//
// Every op node is checked, in this order; the first failing check rejects it:
//   1. CheckIfOutputNode (given `skip_these_nodes`) - nodes it flags are
//      skipped (logged as output nodes).
//   2. NGraphPlacementRequested - nGraph placement must be requested.
//   3. ConfirmationOk - the op's confirmation function must accept the node.
//      Ops returned by api::GetDisabledOps() are removed from the confirmation
//      map first, so they are rejected here as "not supported".
//   4. TypeConstraintOk - the node's inputs must satisfy the type constraints.
//   5. Backend::IsSupported - the current backend must support the op type.
// Nodes passing every check get the attribute
// "_ngraph_marked_for_clustering" = true, and then any op-specific attributes
// from GetAttributeSetters().
//
// When api::IsLoggingPlacement() is true, per-op-type counts of rejections at
// step 3 (split into "not supported" and "failed confirmation") and step 4 are
// printed to stdout. Rejections at steps 1, 2 and 5 are only VLOG'ed.
//
// Params:
//   graph            - graph whose op nodes are inspected and annotated.
//   skip_these_nodes - forwarded to CheckIfOutputNode; presumably the names of
//                      nodes that must stay in TF (TODO confirm with caller).
// Returns:
//   errors::Internal if a disabled op has no confirmation function (the error
//   is reported on every call until the disabled-op list is fixed), otherwise
//   the first non-OK Status from a check helper or an attribute setter.
//   Attributes are written only after the whole graph has been scanned, so a
//   failing check leaves the graph untouched. A failing attribute setter can
//   leave earlier nodes already marked.
//
// NOTE(review): the function-local statics below are read and written without
// any lock; this assumes MarkForClustering is never run concurrently — confirm.
Status MarkForClustering(Graph* graph,
                         const std::set<string> skip_these_nodes) {
  const TypeConstraintMap& type_constraint_map = GetTypeConstraintMap();

  // confirmation_function_map is non-const unlike the other maps: entries for
  // disabled ops are erased from it. It is cached across calls and rebuilt
  // only when the set of disabled ops changes.
  static std::map<std::string, ConfirmationFunction> confirmation_function_map =
      GetConfirmationMap();

  const std::map<std::string, SetAttributesFunction>& set_attributes_map =
      GetAttributeSetters();

  //
  // IF YOU ARE ADDING A NEW OP IMPLEMENTATION, YOU MUST ADD A CONFIRMATION
  // FUNCTION, TYPE CONSTRAINTS (IF ANY) AND STATIC INPUTS INDEXES (IF ANY) FOR
  // THE OP HERE.

  // The constraint function should refuse placement if the node is not
  // supported in the builder, and tag the node with any data that will be
  // needed in case the graph is broken up in a later rewrite pass (for example,
  // constant data).

  // The disabled-op set that confirmation_function_map currently reflects.
  static std::set<string> disabled_ops_set = {};

  static bool initialized = false;

  std::set<string> disabled_ops_set_current = api::GetDisabledOps();
  const bool op_set_support_has_changed =
      disabled_ops_set_current != disabled_ops_set;

  if (!initialized || op_set_support_has_changed) {
    // Validate every disabled op against a fresh map before touching the
    // cached statics. If the cached disabled set were committed first, an
    // invalid op would be reported only once. Later calls would then see
    // "no change" and keep running with a partially pruned map.
    std::map<std::string, ConfirmationFunction> updated_map =
        GetConfirmationMap();
    if (op_set_support_has_changed) {
      NGRAPH_VLOG(5) << "Changing op support";
      for (const auto& op : disabled_ops_set_current) {
        auto conf_itr = updated_map.find(op);
        if (conf_itr == updated_map.end()) {
          // Ops without a confirmation function cannot be disabled. This
          // includes NGraphEncapsulate and the other ng ops, which are
          // expected never to appear in the confirmation map.
          return errors::Internal("Tried to disable ngraph unsupported op ",
                                  op);
        }
        NGRAPH_VLOG(5) << "Disabling op: " << op;
        updated_map.erase(conf_itr);
      }
    }
    // All disabled ops are valid: commit the new state atomically (w.r.t.
    // this function's own control flow).
    confirmation_function_map.swap(updated_map);
    disabled_ops_set.swap(disabled_ops_set_current);
    initialized = true;
  }

  std::unordered_map<string, int> no_support_histogram;
  std::unordered_map<string, int> fail_confirmation_histogram;
  std::unordered_map<string, int> fail_constraint_histogram;
  vector<Node*> nodes_marked_for_clustering;

  shared_ptr<Backend> op_backend = BackendManager::GetBackend();
  for (auto node : graph->op_nodes()) {
    bool mark_for_clustering = false;

    // do { ... } while (false) lets each failed check `break` straight to the
    // accept/reject logging below.
    do {
      // check if output node
      bool skip_it = false;
      TF_RETURN_IF_ERROR(CheckIfOutputNode(node, skip_these_nodes, skip_it));
      if (skip_it) {
        NGRAPH_VLOG(5) << "NGTF_OPTIMIZER: Found Output Node: " << node->name()
                       << " - skip marking it for clustering";
        break;
      }

      // check placement
      bool placement_ok = false;
      TF_RETURN_IF_ERROR(NGraphPlacementRequested(node, placement_ok));
      if (!placement_ok) {
        NGRAPH_VLOG(5) << "Placement not requested: " << node->name();
        break;
      }

      // check node's confirmation constraints
      bool confirmation_constraint_ok = false;
      TF_RETURN_IF_ERROR(ConfirmationOk(node, confirmation_function_map,
                                        confirmation_constraint_ok));
      if (!confirmation_constraint_ok) {
        NGRAPH_VLOG(5) << "Node does not meet confirmation constraints: "
                       << node->name();
        // Ops with no confirmation function (including disabled ops) count
        // as "not supported"; ops whose function rejected the node count as
        // "failed confirmation".
        if (confirmation_function_map.find(node->type_string()) ==
            confirmation_function_map.end()) {
          no_support_histogram[node->type_string()]++;
        } else {
          fail_confirmation_histogram[node->type_string()]++;
        }
        break;
      }

      // check input type constraints
      bool type_constraint_ok = false;
      TF_RETURN_IF_ERROR(
          TypeConstraintOk(node, type_constraint_map, type_constraint_ok));
      if (!type_constraint_ok) {
        NGRAPH_VLOG(5) << "Inputs do not meet type constraints: "
                       << node->name();
        fail_constraint_histogram[node->type_string()]++;
        break;
      }

      // Check if op is supported by backend
      bool is_supported = op_backend->IsSupported(node->type_string().c_str());
      if (!is_supported) {
        NGRAPH_VLOG(5) << "TF Op " << node->name() << " of type "
                       << node->type_string()
                       << " is not supported by backend: "
                       << op_backend->Name();
        break;
      }

      // if all constraints are met, mark for clustering
      mark_for_clustering = true;
    } while (false);

    // Nodes are only collected here; the attribute is written after the scan.
    if (mark_for_clustering) {
      NGRAPH_VLOG(4) << "Accepting: " << node->name() << "["
                     << node->type_string() << "]";
      nodes_marked_for_clustering.push_back(node);
    } else {
      NGRAPH_VLOG(4) << "Rejecting: " << node->name() << "["
                     << node->type_string() << "]";
    }
  }

  if (api::IsLoggingPlacement()) {
    std::cout << "\n=============New sub-graph logs=============\n";
    // print summary for nodes failed to be marked
    std::cout << "NGTF_SUMMARY: Op_not_supported: ";
    tf_utils::PrintNodeHistogram(no_support_histogram);
    std::cout << "NGTF_SUMMARY: Op_failed_confirmation: ";
    tf_utils::PrintNodeHistogram(fail_confirmation_histogram);
    std::cout << "NGTF_SUMMARY: Op_failed_type_constraint: ";
    tf_utils::PrintNodeHistogram(fail_constraint_histogram);
  }

  for (auto node : nodes_marked_for_clustering) {
    // TODO(amprocte): move attr name to a constant
    node->AddAttr("_ngraph_marked_for_clustering", true);
    auto it = set_attributes_map.find(node->type_string());
    if (it != set_attributes_map.end()) {
      TF_RETURN_IF_ERROR(it->second(node));
    }
  }

  return Status::OK();
}