in ngraph_bridge/mark_for_clustering.cc [568:722]
// Marks the op nodes of `graph` that nGraph should take over by setting the
// "_ngraph_marked_for_clustering" attribute on them.
//
// A node is accepted only if, checked in this order:
//   1. CheckIfOutputNode() does not flag it as a node to skip
//      (`skip_these_nodes` is forwarded to that check),
//   2. NGraphPlacementRequested() reports that placement is requested,
//   3. ConfirmationOk() accepts it using the confirmation-function map,
//   4. TypeConstraintOk() accepts its input types, and
//   5. the current backend's IsSupported() accepts its op type.
// Op types returned by api::GetDisabledOps() are removed from the
// confirmation-function map, so nodes of those types fail check 3.
//
// Marking is deferred until every node has been examined. Each accepted node
// then receives the marker attribute and, if GetAttributeSetters() has an
// entry for its op type, that setter is run on it.
//
// When api::IsLoggingPlacement() is true, per-op-type counts of rejected nodes
// (no confirmation function / failed confirmation / failed type constraint)
// are printed to stdout.
//
// Returns:
//   - an Internal error if a disabled op has no confirmation function,
//   - an Internal error if no backend is available while checking a node,
//   - otherwise the first non-OK status from the helper checks or the
//     attribute setters, or Status::OK().
Status MarkForClustering(Graph* graph,
                         const std::set<string> skip_these_nodes) {
  const TypeConstraintMap& type_constraint_map = GetTypeConstraintMap();
  const std::map<std::string, SetAttributesFunction>& set_attributes_map =
      GetAttributeSetters();
  //
  // IF YOU ARE ADDING A NEW OP IMPLEMENTATION, YOU MUST ADD A CONFIRMATION
  // FUNCTION, TYPE CONSTRAINTS (IF ANY) AND STATIC INPUTS INDEXES (IF ANY) FOR
  // THE OP HERE.
  // The constraint function should refuse placement if the node is not
  // supported in the builder, and tag the node with any data that will be
  // needed in case the graph is broken up in a later rewrite pass (for example,
  // constant data).

  // confirmation_function_map is non-const unlike the other maps: it is the
  // full confirmation map minus the currently disabled ops. It is cached
  // across calls together with the disabled-op set it was derived from, and
  // rebuilt only when api::GetDisabledOps() reports a different set.
  // NOTE(review): these function-local statics are read and mutated without
  // any locking; confirm that callers never run this pass concurrently.
  static std::map<std::string, ConfirmationFunction> confirmation_function_map =
      GetConfirmationMap();
  static std::set<string> disabled_ops_set = {};
  const std::set<string> disabled_ops_set_current = api::GetDisabledOps();
  if (disabled_ops_set_current != disabled_ops_set) {
    NGRAPH_VLOG(5) << "Changing op support";
    // Build the pruned map off to the side and commit it, together with the
    // disabled set, only after every disabled op has been validated.
    // Committing the set first would leave a partially pruned map in the
    // cache. Later calls would then reuse it silently, without ever
    // reporting the error again.
    std::map<std::string, ConfirmationFunction> updated_map =
        GetConfirmationMap();
    for (const auto& op_name : disabled_ops_set_current) {
      auto conf_itr = updated_map.find(op_name);
      if (conf_itr == updated_map.end()) {
        // Note: This error means, we cannot disable NGraphEncapsulate and other
        // ng ops, because they are expected to never appear in
        // confirmation_function_map
        return errors::Internal("Tried to disable ngraph unsupported op ",
                                op_name);
      }
      NGRAPH_VLOG(5) << "Disabling op: " << op_name;
      updated_map.erase(conf_itr);
    }
    confirmation_function_map.swap(updated_map);
    disabled_ops_set = disabled_ops_set_current;
  }

  // Per-op-type rejection counts, only reported when placement logging is on.
  std::unordered_map<string, int> no_support_histogram;
  std::unordered_map<string, int> fail_confirmation_histogram;
  std::unordered_map<string, int> fail_constraint_histogram;
  vector<Node*> nodes_marked_for_clustering;
  shared_ptr<Backend> op_backend = BackendManager::GetBackend();
  for (auto node : graph->op_nodes()) {
    bool mark_for_clustering = false;
    // do { ... } while (false) lets each failed check `break` straight to the
    // accept/reject logging below.
    do {
      // check if output node
      bool skip_it = false;
      TF_RETURN_IF_ERROR(CheckIfOutputNode(node, skip_these_nodes, skip_it));
      if (skip_it) {
        NGRAPH_VLOG(5) << "NGTF_OPTIMIZER: Found Output Node: " << node->name()
                       << " - skip marking it for clustering";
        break;
      }
      // check placement
      bool placement_ok = false;
      TF_RETURN_IF_ERROR(NGraphPlacementRequested(node, placement_ok));
      if (!placement_ok) {
        NGRAPH_VLOG(5) << "Placement not requested: " << node->name();
        break;
      }
      // check node's confirmation constraints
      bool confirmation_constraint_ok = false;
      TF_RETURN_IF_ERROR(ConfirmationOk(node, confirmation_function_map,
                                        confirmation_constraint_ok));
      if (!confirmation_constraint_ok) {
        NGRAPH_VLOG(5) << "Node does not meet confirmation constraints: "
                       << node->name();
        // Distinguish "no confirmation function registered (or disabled)"
        // from "confirmation function ran and rejected the node".
        if (confirmation_function_map.find(node->type_string()) ==
            confirmation_function_map.end()) {
          // not found
          no_support_histogram[node->type_string()]++;
        } else {
          // found
          fail_confirmation_histogram[node->type_string()]++;
        }
        break;
      }
      // check input type constraints
      bool type_constraint_ok = false;
      TF_RETURN_IF_ERROR(
          TypeConstraintOk(node, type_constraint_map, type_constraint_ok));
      if (!type_constraint_ok) {
        NGRAPH_VLOG(5) << "Inputs do not meet type constraints: "
                       << node->name();
        fail_constraint_histogram[node->type_string()]++;
        break;
      }
      // Check if op is supported by backend. Report a missing backend as an
      // error instead of dereferencing a null pointer.
      if (op_backend == nullptr) {
        return errors::Internal(
            "No nGraph backend available to check op support for node ",
            node->name());
      }
      bool is_supported = op_backend->IsSupported(node->type_string().c_str());
      if (!is_supported) {
        NGRAPH_VLOG(5) << "TF Op " << node->name() << " of type "
                       << node->type_string()
                       << " is not supported by backend: "
                       << op_backend->Name();
        break;
      }
      // if all constraints are met, mark for clustering
      mark_for_clustering = true;
    } while (false);
    // Nodes that satisfy every constraint are collected here; the
    // _ngraph_marked_for_clustering attribute is set after the loop.
    if (mark_for_clustering) {
      NGRAPH_VLOG(4) << "Accepting: " << node->name() << "["
                     << node->type_string() << "]";
      nodes_marked_for_clustering.push_back(node);
    } else {
      NGRAPH_VLOG(4) << "Rejecting: " << node->name() << "["
                     << node->type_string() << "]";
    }
  }
  if (api::IsLoggingPlacement()) {
    std::cout << "\n=============New sub-graph logs=============\n";
    // print summary for nodes failed to be marked
    std::cout << "NGTF_SUMMARY: Op_not_supported: ";
    tf_utils::PrintNodeHistogram(no_support_histogram);
    std::cout << "NGTF_SUMMARY: Op_failed_confirmation: ";
    tf_utils::PrintNodeHistogram(fail_confirmation_histogram);
    std::cout << "NGTF_SUMMARY: Op_failed_type_constraint: ";
    tf_utils::PrintNodeHistogram(fail_constraint_histogram);
  }
  for (auto node : nodes_marked_for_clustering) {
    // TODO(amprocte): move attr name to a constant
    node->AddAttr("_ngraph_marked_for_clustering", true);
    auto it = set_attributes_map.find(node->type_string());
    if (it != set_attributes_map.end()) {
      TF_RETURN_IF_ERROR(it->second(node));
    }
  }
  return Status::OK();
}