in ngraph_bridge/assign_clusters.cc [295:750]
// Partitions the nodes of `graph` into clusters and tags the nodes of every
// cluster made of nodes marked for clustering with an "_ngraph_cluster"
// attribute.
//
// The work is done in the following phases:
//   1. Every node starts in a cluster of its own, mirrored by one node in a
//      GraphCycles structure.
//   2. Every op-to-op edge (NextIteration edges excluded) is inserted into
//      GraphCycles; an edge that would create a cycle aborts the pass.
//   3. For every static input not fed by a Const, a shadow node and two
//      shadow edges are added so that producer and consumer cannot merge.
//   4. Edges are contracted repeatedly until a full sweep changes nothing.
//   5. Clusters containing marked nodes receive a fresh ClusterManager id,
//      which is written onto each of their nodes.
//   6. If api::IsLoggingPlacement() is set, one extra sweep records why each
//      remaining edge did not contract and histograms are printed to stdout.
//
// Returns Status::OK() on success. Returns errors::Unimplemented if the input
// graph already has a cycle, errors::Internal on violated internal
// invariants, and propagates any error produced by the helpers it calls.
Status AssignClusters(Graph* graph) {
  std::map<Node*, std::shared_ptr<Cluster>> cluster_map;

#if !defined(NGRAPH_TF_DISABLE_DEADNESS_CHECK)
  std::unique_ptr<DeadnessAnalysis> deadness_analyzer;
  TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph, &deadness_analyzer));
  // This map is used only for error checking
  std::map<Node*, std::string> nodes_predicate_map;
#endif

  GraphCycles gc;

  // Initial Step: Each node is a cluster of its own
  for (auto node : graph->nodes()) {
    int new_index = gc.NewNode();
    cluster_map[node] = std::make_shared<Cluster>();
    cluster_map[node]->index = new_index;
    cluster_map[node]->nodes.insert(node);
    NGRAPH_VLOG(5) << "Creating graphcycle Node: " << new_index << " for "
                   << node->name() << "[" << node->type_string() << "]";

#if !defined(NGRAPH_TF_DISABLE_DEADNESS_CHECK)
    // get predicate string for the node
    string pred_string;
    TF_RETURN_IF_ERROR(deadness_analyzer->GetNodePredicate(*node, pred_string));
    nodes_predicate_map[node] = pred_string;
    cluster_map[node]->predicate_string = pred_string;
    cluster_map[node]->outgoing_edges = std::set<const Edge*>(
        node->out_edges().begin(), node->out_edges().end());
    NGRAPH_VLOG(5) << node->name() << "[" << node->type_string() << "]"
                   << " : Predicate " << pred_string;
#endif
  }

  // Check for existing cyclicity in the graph
  for (auto edge : graph->edges()) {
    Node* src = edge->src();
    Node* dst = edge->dst();

    // Skip source/sink
    if (!src->IsOp() || !dst->IsOp()) {
      continue;
    }

    // Skip NextIteration
    if (src->IsNextIteration() || dst->IsNextIteration()) {
      continue;
    }

    if (!gc.InsertEdge(cluster_map[src]->index, cluster_map[dst]->index)) {
      NGRAPH_VLOG(5) << "Failing due to cycle";
      return errors::Unimplemented(
          "Input graph has a cycle (inserting an edge from ",
          src->DebugString(), " to ", dst->DebugString(),
          " would create a cycle)");
    }
  }

  // If we wish to add a constraint that 2 particular nodes not lie in the same
  // cluster, then all we have to do is add 2 'shadow' edges and 1 'shadow' node
  // in the gc data structure between the 2 nodes. The shadow edges go from the
  // node closer to toposort source to the node closer to sink, through a shadow
  // node. src--->S--->dst. (not the other way round, else it would introduce a
  // cycle).
  // TF world node (o), gc world node (+), static input *
  // Normal edge translation:
  //   (o)---->(o)    ==>   (+)---->(+)
  // Static input edge translation:
  //   (o)---->*(o)   ==>   (+)---->(+)
  //                         |       ^
  //                         |       |
  //                          --(+)--
  // The contraction only happens on 'real' edges (edges that are
  // present in the TF graph itself). Therefore the shadow edges in the gc
  // data structure will never suffer contraction. Anytime the shadow path's src
  // and dst attempt a merge (by contracting some real edge between them),
  // the shadow path will introduce a cycle and not allow it
  // Warning: this relies on the fact that we attempt to contract 'real' edges
  // from the TF graph. For optimization, one might attempt to contract the gc
  // edges, which keep decreasing unlike the TF edges. But this fix would break
  // then, since we have broken the contract that an edge in gc implies an edge
  // in TF in this fix
  for (auto node : graph->op_nodes()) {
    std::vector<int32> static_inputs;
    GetStaticInputs(node, &static_inputs);
    if (static_inputs.size() > 0) {
      std::vector<const Edge*> edges_to_node;
      TF_RETURN_IF_ERROR(node->input_edges(&edges_to_node));
      for (auto static_inp_idx : static_inputs) {
        // Guard the lookup below: an index outside of the input edge list
        // would otherwise be an out-of-bounds read.
        if (static_inp_idx < 0 ||
            static_cast<size_t>(static_inp_idx) >= edges_to_node.size()) {
          return errors::Internal("Static input index ", static_inp_idx,
                                  " is out of range for node ", node->name(),
                                  " which has ", edges_to_node.size(),
                                  " input edges");
        }
        auto static_edge = edges_to_node[static_inp_idx];
        if (static_edge->src()->type_string() != "Const") {
          int shadow_node_index = gc.NewNode();
          bool gc_success = gc.InsertEdge(
              cluster_map[static_edge->src()]->index, shadow_node_index);
          gc_success &= gc.InsertEdge(shadow_node_index,
                                      cluster_map[static_edge->dst()]->index);
          if (!gc_success)
            return errors::Internal(
                "Unable to create shadow edges in GraphCycles");
        }
      }
    }
  }

  NGRAPH_VLOG(2) << "Starting contraction";
  bool changed;
  bool collect_non_contracting_edge_info = false;  // Must init with false

  // 6 exhaustive reasons why edges might non contract
  // The reasons are not mutually exclusive, but there is an order of priority
  // that makes them mutually exclusive
  enum EdgeNonContractionReasons {
    NOTANOP,      // edge connects to non-ops
    UNSUPPORTED,  // either the src or dst is an unsupported op
    DEADNESS,     // deadness criteria not met
    SAMECLUSTER,  // both ends lie in the same cluster
    STATICINPUT,  // static input in dst (not fed by const)
    PATHEXISTS    // base case reason. contraction causes cycles
  };
  // Converts the enum to string. It is indexed by the enum value, so the
  // entries must stay in the same order as EdgeNonContractionReasons.
  static std::vector<string> reason_string(
      {"NOTANOP", "UNSUPPORTED", "DEADNESS", "SAMECLUSTER", "STATICINPUT",
       "PATHEXISTS"});

  // a cluster pair is the string "cluster1_id, cluster2_id"
  // Using string, because a pair won't hash unless implemented
  // Note that we store a vector of "reasons", because there could be multiple
  // reasons
  using ClusterPairToReason =
      std::unordered_map<std::string, std::vector<EdgeNonContractionReasons>>;
  ClusterPairToReason cluster_separation_reason;
  auto get_string_key = [](int x, int y) {
    return to_string(x) + "," + to_string(y);
  };

  // (src id, dst id) -> (src predicate, dst predicate, other neighbours
  // predicates)
  std::unordered_map<std::string, tuple<string, string, vector<string>>>
      deadness_info;

  do {
    changed = false;

    // Logs one non-contracted edge together with the reason. reason_string
    // has static storage duration, hence no capture is needed.
    auto log_reason = [](EdgeNonContractionReasons reason, Edge* edge) {
      NGRAPH_VLOG(0) << "NONCONTRACTION: " << reason_string[reason] << ": "
                     << edge->src()->name() << "<" << edge->src()->type_string()
                     << ">"
                     << "[" << edge->src_output() << "] -> "
                     << edge->dst()->name() << "<" << edge->dst()->type_string()
                     << ">"
                     << "[" << edge->dst_input() << "]";
    };

    for (auto edge : graph->edges()) {
      Node* src = edge->src();
      Node* dst = edge->dst();
      int src_index = cluster_map[src]->index;
      int dst_index = cluster_map[dst]->index;

      if (!src->IsOp() || !dst->IsOp()) {
        if (collect_non_contracting_edge_info) {
          log_reason(EdgeNonContractionReasons::NOTANOP, edge);
          cluster_separation_reason[get_string_key(src_index, dst_index)]
              .push_back(EdgeNonContractionReasons::NOTANOP);
        }
        continue;
      }

      if (!NodeIsMarkedForClustering(src) || !NodeIsMarkedForClustering(dst)) {
        NGRAPH_VLOG(5) << "Skipping (not marked): " << src->name() << "["
                       << edge->src_output() << "]@" << src_index << " -> "
                       << dst->name() << "[" << edge->dst_input() << "]@"
                       << dst_index;
        if (collect_non_contracting_edge_info) {
          log_reason(EdgeNonContractionReasons::UNSUPPORTED, edge);
          cluster_separation_reason[get_string_key(src_index, dst_index)]
              .push_back(EdgeNonContractionReasons::UNSUPPORTED);
        }
        continue;
      }

#if !defined(NGRAPH_TF_DISABLE_DEADNESS_CHECK)
      // check if the edge can be contracted with respect to deadness
      bool is_deadness_ok = false;
      TF_RETURN_IF_ERROR(
          CanContractEdgeDeadnessCheck(edge, cluster_map, is_deadness_ok));
      if (!is_deadness_ok) {
        // do not contract, src and dst node cannot be in the same cluster
        NGRAPH_VLOG(5) << "Skipping (deadness not ok): " << src->name() << "["
                       << edge->src_output() << "]@" << src_index << " -> "
                       << dst->name() << "[" << edge->dst_input() << "]@"
                       << dst_index;
        if (collect_non_contracting_edge_info) {
          log_reason(EdgeNonContractionReasons::DEADNESS, edge);
          cluster_separation_reason[get_string_key(src_index, dst_index)]
              .push_back(EdgeNonContractionReasons::DEADNESS);
          const auto& src_cluster = cluster_map[src];
          vector<string> neighbours_predicate;
          // Collect predicates of src's neighbours (except dst)
          for (const Edge* src_cluster_edge : src_cluster->outgoing_edges) {
            if (src_cluster_edge != edge) {
              neighbours_predicate.push_back(
                  cluster_map[src_cluster_edge->dst()]->predicate_string);
            }
          }
          deadness_info[get_string_key(src_index, dst_index)] = make_tuple(
              cluster_map.at(src)->predicate_string,
              cluster_map.at(dst)->predicate_string, neighbours_predicate);
        }
        continue;
      }
#endif

      // Check if contracting the edge will lead to cycles
      // if not, MergeClusters
      if (gc.HasEdge(src_index, dst_index) &&
          gc.ContractEdge(src_index, dst_index)) {
        MergeClusters(edge, cluster_map);
        // something changed
        changed = true;
      } else {
        if (collect_non_contracting_edge_info) {
          // either static input
          // or there exists a longer path, so contracting this edge causes
          // cycles
          std::vector<int32> static_inputs;
          GetStaticInputs(dst, &static_inputs);
          bool is_static = std::find(static_inputs.begin(), static_inputs.end(),
                                     edge->dst_input()) != static_inputs.end();
          bool is_not_const = src->type_string() != "Const";
          // 3 possible reasons here:
          // src dst lies in same cluster, so nothing to do (trivial cycle
          // induced in graphcycles)
          // dst has static input
          // a longer irreducible path exists
          auto reason = (src_index == dst_index
                             ? EdgeNonContractionReasons::SAMECLUSTER
                             : ((is_not_const && is_static)
                                    ? EdgeNonContractionReasons::STATICINPUT
                                    : EdgeNonContractionReasons::PATHEXISTS));
          log_reason(reason, edge);
          cluster_separation_reason[get_string_key(src_index, dst_index)]
              .push_back(reason);
        }
      }
    }

    if (!changed && api::IsLoggingPlacement()) {
      // This will be entered only once if logging is enabled
      // When entered, it will force the do-while to run one last time,
      // collecting information
      if (!collect_non_contracting_edge_info) {
        changed = true;
        collect_non_contracting_edge_info = true;
      }
    }
  } while (changed);
  NGRAPH_VLOG(2) << "Contraction done";

  NGRAPH_VLOG(2) << "Starting tagging";
  // Several nodes share one Cluster object after merging; `seen` makes sure
  // every cluster is processed once.
  std::set<Cluster*> seen;
  // map from cluster id (GraphCycles index) to ngraph_cluster id
  unordered_map<int, int> cluster_to_encapsulate;

  // Iterate by reference to avoid copying a shared_ptr for every node.
  for (const auto& kv : cluster_map) {
    auto cluster = kv.second.get();
    if (seen.count(cluster) != 0) {
      continue;
    }

    bool has_ngraph_ops = false;
    bool has_non_ngraph_ops = false;

    for (auto node : cluster->nodes) {
      if (NodeIsMarkedForClustering(node)) {
        has_ngraph_ops = true;
// Some sanity checks for deadness
#if !defined(NGRAPH_TF_DISABLE_DEADNESS_CHECK)
        TF_RETURN_IF_ERROR(CheckNodeClusterAssignmentWRTDeadness(
            node, nodes_predicate_map, cluster_map));
#endif
      } else {
        has_non_ngraph_ops = true;
      }
    }

    if (has_ngraph_ops && has_non_ngraph_ops) {
      NGRAPH_VLOG(2) << "Cluster " << cluster->index
                     << " has both nGraph and non-nGraph nodes";
      for (auto node : cluster->nodes) {
        NGRAPH_VLOG(2) << (NodeIsMarkedForClustering(node)
                               ? "nGraph node: "
                               : "non-nGraph node: ")
                       << node->name() << " [" << node->type_string() << "]";
      }
      return errors::Internal("Cluster ", cluster->index,
                              " has both nGraph and non-nGraph nodes");
    }

    if (!has_ngraph_ops) {
      seen.insert(cluster);
      continue;
    }

    size_t cluster_idx = ClusterManager::NewCluster();

    for (auto node : cluster->nodes) {
      if (NGRAPH_VLOG_IS_ON(5)) {
        NGRAPH_VLOG(5) << ">> cluster " << cluster_idx << ": " << node->id()
                       << " " << node << " :: " << node->name() << " ["
                       << node->type_string() << "]";
      }

      if (!NodeIsMarkedForClustering(node)) {
        return errors::Internal("Node ", node->DebugString(),
                                " was not marked for clustering but was "
                                "placed in an nGraph cluster.");
      }

      // TODO(amprocte): move attr name to a constant
      node->AddAttr("_ngraph_cluster", static_cast<int>(cluster_idx));
    }

    // The mapping depends on the cluster only, so it is recorded once per
    // cluster rather than once per node.
    if (api::IsLoggingPlacement()) {
      cluster_to_encapsulate[cluster->index] = static_cast<int>(cluster_idx);
    }

    seen.insert(cluster);
  }
  NGRAPH_VLOG(2) << "Tagging done";

  if (api::IsLoggingPlacement()) {
    // the number of elements in the reasons enum; reason_string has exactly
    // one entry per enum value
    const int num_reasons = static_cast<int>(reason_string.size());
    // histogram of reasons of non-contraction of clusters
    vector<int> reason_count_clusters(num_reasons, 0);
    vector<int> reason_count_encapsulates(num_reasons, 0);
    int num_non_contracted = 0;

    auto forbidden_reasons_for_not_merging_clusters =
        std::set<EdgeNonContractionReasons>{
            EdgeNonContractionReasons::NOTANOP,
            EdgeNonContractionReasons::UNSUPPORTED,
            EdgeNonContractionReasons::SAMECLUSTER};
    std::function<bool(EdgeNonContractionReasons)> is_forbidden_reason =
        [&forbidden_reasons_for_not_merging_clusters](
            EdgeNonContractionReasons r) -> bool {
      return forbidden_reasons_for_not_merging_clusters.find(r) !=
             forbidden_reasons_for_not_merging_clusters.end();
    };

    std::cout
        << "Encapsulate i->j: non contraction reason histogram (Cannot be "
           "UNSUPPORTED, NOTANOP or SAMECLUSTER because unsupported ops will "
           "not be "
           "assigned an encapsulate)\n";

    // Iterate by reference to avoid copying the key and the reason vector.
    for (const auto& it : cluster_separation_reason) {
      num_non_contracted += it.second.size();
      std::vector<std::string> cluster_id_vector =
          absl::StrSplit(it.first, ',');
      // function to find if this cluster became an ngraph_cluster
      // returns ngraph_cluster id if yes, else returns -1
      auto find_in_map = [&cluster_to_encapsulate, &cluster_id_vector](int x) {
        auto itr = cluster_to_encapsulate.find(stoi(cluster_id_vector[x]));
        return itr == cluster_to_encapsulate.end() ? -1 : itr->second;
      };
      int src_encapsulate = find_in_map(0);
      int dst_encapsulate = find_in_map(1);
      bool both_src_dst_are_encapsulates =
          src_encapsulate >= 0 && dst_encapsulate >= 0;
      bool src_dst_are_distinct = src_encapsulate != dst_encapsulate;
      vector<int> reason_count_encapsulates_for_pair(num_reasons, 0);
      bool pair_has_reason = false;
      // NOTE(review): deadness_string is assembled below but is not printed
      // anywhere in this function - confirm whether it was meant to be
      // emitted next to the per-pair histogram.
      string deadness_string = "";
      for (auto& inner_itr : it.second) {
        // This if checks if the pair are 2 distinct encapsulates
        // In which case it asserts certain non-merging reasons are not possible
        // And also prints the reasons of non-merging
        if (both_src_dst_are_encapsulates && src_dst_are_distinct) {
          if (is_forbidden_reason(inner_itr)) {
            // Report the reason by name instead of its raw enum value.
            return errors::Internal(
                reason_string[inner_itr],
                " should not be a reason why 2 encapsulates did not "
                "merge, because unsupported ops would not end up in "
                "encapsulates");
          }
          pair_has_reason = true;
          reason_count_encapsulates_for_pair[inner_itr]++;
          auto deadness_itr = deadness_info.find(it.first);
          if (deadness_itr != deadness_info.end()) {
            const auto& deadness_predicates_tpl = deadness_itr->second;
            deadness_string +=
                ("Source[" + to_string(src_encapsulate) + "] predicate: " +
                 std::get<0>(deadness_predicates_tpl) + " Destination[" +
                 to_string(dst_encapsulate) + "] predicate: " +
                 std::get<1>(deadness_predicates_tpl) +
                 " Neighbours predicates: " +
                 absl::StrJoin(std::get<2>(deadness_predicates_tpl), "\n"));
          }
        }
        reason_count_clusters[inner_itr]++;
      }  // end of the for over each cluster pair's reason vector

      if (pair_has_reason) {
        std::cout << src_encapsulate << "->" << dst_encapsulate << ": ";
        for (int reason_id = 0; reason_id < num_reasons; reason_id++) {
          if (!is_forbidden_reason(
                  static_cast<EdgeNonContractionReasons>(reason_id))) {
            // Update global histogram with current pair's counts
            reason_count_encapsulates[reason_id] +=
                reason_count_encapsulates_for_pair[reason_id];
            cout << reason_string[reason_id] << ":"
                 << reason_count_encapsulates_for_pair[reason_id]
                 << (reason_id < (num_reasons - 1) ? ", " : "");
          }
        }
        std::cout << endl;
      }
    }  // end of the for over cluster_separation_reason
    std::cout << endl;

    // The information-collecting sweep records exactly one reason per edge
    // that did not contract, so the total must equal the edge count.
    if (num_non_contracted != graph->num_edges()) {
      return errors::Internal(
          "Number of non contracted edges ", num_non_contracted,
          " should match number of edges ", graph->num_edges());
    }

    // Prints one summary line with the count of every reason that the filter
    // does not reject.
    auto print_reason_summary =
        [&num_reasons](const vector<int>& reasons_count,
                       const std::function<bool(EdgeNonContractionReasons)>&
                           forbidden_reasons_filter) {
          bool first = true;
          for (int i = 0; i < num_reasons; i++) {
            if (!forbidden_reasons_filter(
                    static_cast<EdgeNonContractionReasons>(i))) {
              std::cout << (first ? "NGTF_SUMMARY: " : "") << reason_string[i]
                        << ": " << reasons_count[i]
                        << (i < (num_reasons - 1) ? ", " : "\n");
              first = false;
            }
          }
        };

    std::cout << "NGTF_SUMMARY: Summary of reasons why a pair of edge "
                 "connected encapsulates did not merge\n";
    print_reason_summary(reason_count_encapsulates, is_forbidden_reason);
    std::cout << "NGTF_SUMMARY: Summary of reasons why a pair of edge "
                 "connected clusters did not merge\n";
    // No reason is filtered out for the cluster level summary.
    print_reason_summary(reason_count_clusters,
                         [](EdgeNonContractionReasons x) {
                           NGRAPH_VLOG(5) << "EdgeNonContractionReasons: " << x;
                           return false;
                         });
  }

  return Status::OK();
}