in tensorflow/tensorflow/core/graph/graph_partition.cc [977:1250]
// Splits graph 'g' into one GraphDef per location and stores the results in
// '*partitions', keyed by the string that opts.node_to_loc() returns for each
// node. Any previous contents of '*partitions' are cleared first.
//
// Every op node of 'g' is copied into the GraphDef of its location. An edge
// whose endpoints fall in the same GraphDef becomes a plain input, unless
// NeedSameDeviceSendRecv() reports that a transfer is still required. All
// other edges are replaced by a send node on the source side and a recv node
// on the destination side (see AddSend()/AddRecv()). Control edges that need a
// transfer are carried by a dummy const created with AddDummyConst().
//
// Side effects: unless opts.control_flow_added is set, 'g' itself is mutated
// by AddControlFlow() before partitioning.
//
// Returns the first non-OK status reported by a helper, InvalidArgument when a
// node has fewer data in-edges than declared inputs, and OK otherwise. On an
// error return '*partitions' may be left partially populated.
Status Partition(const PartitionOptions& opts, Graph* g,
std::unordered_map<string, GraphDef>* partitions) {
Status status;
partitions->clear();
GraphInfo g_info;
if (!opts.control_flow_added) {
// Add the "code" for distributed execution of control flow. Code is
// added only for the frames that are placed on multiple devices. The
// new graph is an equivalent transformation of the original graph and
// has the property that it can be subsequently partitioned arbitrarily
// (down to the level of individual device) for distributed execution.
status = AddControlFlow(opts, g, &g_info);
if (!status.ok()) return status;
}
// At this point, all the graph mutations have been done. Build memory
// and device type info for every node and edge in the graph.
status = BuildMemoryDeviceInfo(*g, &g_info);
if (!status.ok()) return status;
// Scratch state declared once and reset/reassigned for every node below.
string dstp;
std::vector<const Edge*> inputs;
// Remembers, per (src node id, src output, dst partition, on-host flag), the
// recv that was already created, so that several consumers in the same
// partition share one send/recv pair.
// NOTE(review): the argument 3 is presumably an initial bucket count --
// confirm against the definition of DupRecvTable.
DupRecvTable dup_recv(3);
// For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
// edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
// edge to dst. We will add a control edge for every pair in
// (ref_recvs x ref_control_inputs).
std::vector<NodeDef*> ref_recvs;
std::vector<string> ref_control_inputs;
// Number of send/recv pairs created for data and control edges respectively.
// They are only reported through the VLOG at the end of the function.
int32 num_data = 0;
int32 num_control = 0;
// One iteration per op node: copy the node into its partition, then rebuild
// its input list edge by edge.
for (const Node* dst : g->op_nodes()) {
dstp = opts.node_to_loc(dst);
GraphDef* dst_graph = &(*partitions)[dstp];
NodeDef* dst_def = dst_graph->add_node();
*dst_def = dst->def();
MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
dst_def->set_device(dst->assigned_device_name());
dst_def->clear_input(); // Inputs are filled below
// When start times must be recorded, make sure dst_def carries a
// "_start_time" attr: an existing attr is kept, a missing one is filled in
// from opts.start_times, and any other lookup error aborts the partitioning.
if (opts.need_to_record_start_times) {
int64 start_time;
status = GetNodeAttr(*dst_def, "_start_time", &start_time);
if (errors::IsNotFound(status)) {
start_time = opts.start_times[dst->id()].value();
AddNodeAttr("_start_time", start_time, dst_def);
} else if (!status.ok()) {
return status;
}
}
// Arrange the incoming edges to dst so that input[i] holds the
// input flowing into slot numbered i. Trailing entries in input[]
// hold control edges.
inputs.clear();
inputs.resize(dst->num_inputs(), nullptr);
ref_recvs.clear();
ref_control_inputs.clear();
// 'control_flow_edge' is the last control edge seen whose source is a Merge
// node for which IsControlLoop() holds; 'num_control_flow_edges' counts all
// such edges. They are kept out of 'inputs' and handled separately below.
const Edge* control_flow_edge = nullptr;
int32 num_control_flow_edges = 0;
int32 num_input_edges = 0;
for (const Edge* edge : dst->in_edges()) {
if (edge->IsControlEdge()) {
if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
// This is one of the control edges added for control flow. There
// can be multiple such edges as the dest node may have multiple
// remote inputs. We keep track of the number of such edges.
control_flow_edge = edge;
++num_control_flow_edges;
} else {
inputs.push_back(edge);
}
} else {
DCHECK(inputs[edge->dst_input()] == nullptr);
inputs[edge->dst_input()] = edge;
++num_input_edges;
}
}
// Reject nodes whose data inputs are not all connected.
// NOTE(review): this count check assumes no two data edges target the same
// slot, which is only DCHECKed above; if that were violated a nullptr entry
// could remain in 'inputs' and be dereferenced in the loop below -- confirm
// that Graph guarantees this.
if (num_input_edges != dst->num_inputs()) {
return errors::InvalidArgument("Incomplete graph, missing ",
(dst->num_inputs() - num_input_edges),
" inputs for ", dst->name());
}
// Process in order so that all data edges are added as inputs to
// dst in Edge::dst_input() order.
for (const Edge* edge : inputs) {
const Node* src = edge->src();
if (!src->IsOp()) continue; // Skip Sink/Source nodes.
GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
// Same partition and compatible memory types:
AddInput(dst_def, src->name(), edge->src_output());
// Control inputs and non-ref data inputs are remembered so that any ref
// recv created for dst can be made to depend on them (AddReadControl).
if (edge->IsControlEdge() ||
!IsRefType(src->output_type(edge->src_output()))) {
ref_control_inputs.push_back(src->name());
}
continue;
}
// From here on the edge needs a send/recv pair.
// Start times are only looked up when opts.scheduling_for_recvs is set;
// otherwise both stay 0. A missing "_start_time" attr falls back to
// opts.start_times only if opts.need_to_record_start_times is set; in every
// other non-OK case the lookup status is returned to the caller.
int64 send_start_time = 0;
int64 recv_start_time = 0;
if (opts.scheduling_for_recvs) {
status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
send_start_time = opts.start_times[src->id()].value();
} else if (!status.ok()) {
return status;
}
status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
recv_start_time = opts.start_times[dst->id()].value();
} else if (!status.ok()) {
return status;
}
}
// Check whether there is already a send/recv pair transferring
// the same tensor/control from the src to dst partition.
const bool on_host = IsDstInputOnHost(edge, g_info);
DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
auto iter = dup_recv.find(key);
if (iter != dup_recv.end()) {
// We found one. Reuse the data/control transferred already.
const string& recv_node_name = iter->second.recv->name();
if (edge->IsControlEdge()) {
AddInput(dst_def, recv_node_name, Graph::kControlSlot);
} else {
AddInput(dst_def, recv_node_name, 0);
}
ref_control_inputs.push_back(recv_node_name);
// We want the start_time for the recv to be the smallest of the start
// times of its consumers. So we update this whenever we use a recv,
// and write it out to the attribute at the end of the subroutine.
if (iter->second.start_time > recv_start_time) {
iter->second.start_time = recv_start_time;
}
continue;
}
// No reusable recv: build a new send/recv pair. 'send_from' names the
// output that the send node will read.
NodeDefBuilder::NodeOut send_from;
if (edge->IsControlEdge()) {
// Insert a dummy const node that will generate a tiny
// data element to be sent from send to recv.
VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
<< src->name() << "] -> " << dst->assigned_device_name() << "["
<< dst->name() << "]";
NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
if (!status.ok()) return status;
// Set the start time for this dummy node.
if (opts.scheduling_for_recvs) {
AddNodeAttr("_start_time", send_start_time, dummy);
}
// The dummy runs after src, so sending its output preserves the ordering
// that the original control edge expressed.
AddInput(dummy, src->name(), Graph::kControlSlot);
send_from.Reset(dummy->name(), 0, DT_FLOAT);
} else {
send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
}
// Need to split edge by placing matching send/recv nodes on
// the src/dst sides of the edge.
NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
send_start_time, &status);
if (!status.ok()) return status;
// AddRecv returns the node dst should consume and reports the underlying
// recv node through 'real_recv'; the code below allows the two to differ.
NodeDef* real_recv = nullptr;
NodeDef* recv =
AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
if (!status.ok()) return status;
// Fix up the control flow edge.
// NOTE(yuanbyu): 'real_recv' must be the real recv node.
if (src_graph == dst_graph) {
// For same device send/recv, add a control edge from send to recv.
// This prevents the asynchronous recv kernel from being scheduled
// before the data is available.
AddInput(real_recv, send->name(), Graph::kControlSlot);
} else if (control_flow_edge != nullptr) {
// Redirect control edge to the real recv since this is not the same
// device send/recv.
--num_control_flow_edges;
AddInput(real_recv, control_flow_edge->src()->name(),
Graph::kControlSlot);
}
if (!edge->IsControlEdge() &&
IsRefType(src->output_type(edge->src_output()))) {
// Ref recvs are never entered into 'dup_recv', so their start time is
// written immediately instead of in the final pass over 'dup_recv'.
// NOTE(review): unlike the dummy const above and the final pass below,
// this write is not guarded by opts.scheduling_for_recvs, so
// "_start_time" is set to 0 on these recvs when scheduling is off --
// confirm that this is intended.
AddNodeAttr("_start_time", recv_start_time, recv);
if (real_recv != recv) {
AddNodeAttr("_start_time", recv_start_time, real_recv);
}
// If src is of ref type and the edge is not a control edge, dst has
// read semantics and therefore we must control the recv.
ref_recvs.push_back(real_recv);
} else {
// Memorize the send/recv pair, only if this is not a "ref" edge.
// NOTE(yuanbyu): Collapsing ref edges requires extreme care so
// for now we don't do it.
dup_recv[key] = {recv, real_recv, recv_start_time};
ref_control_inputs.push_back(recv->name());
}
// Finally connect dst to the new recv and update the statistics.
if (edge->IsControlEdge()) {
++num_control;
AddInput(dst_def, recv->name(), Graph::kControlSlot);
} else {
++num_data;
AddInput(dst_def, recv->name(), 0);
}
}
// Add control edges from 'ref_control_inputs' to 'ref_recvs'.
// NOTE(yuanbyu): Adding these control edges should not introduce
// deadlocks. 'dst' has implicit "read" nodes that, when we split
// across devices, are made explicit; Retargeting the dependencies
// to 'dst' to those nodes would not introduce cycles if there isn't
// one before the transformation.
// NOTE(yuanbyu): This may impact performance because it defers the
// execution of recvs until all the other inputs become available.
AddReadControl(ref_recvs, ref_control_inputs);
// Add back the control edges for control flow that are not used.
// 'num_control_flow_edges' was decremented once for every edge redirected
// to a recv above, so only the remaining count is re-attached to dst.
if (control_flow_edge != nullptr) {
for (int i = 0; i < num_control_flow_edges; ++i) {
AddInput(dst_def, control_flow_edge->src()->name(),
Graph::kControlSlot);
}
}
}
// Prefer the function library supplied in 'opts'; fall back to the one owned
// by the graph.
const FunctionLibraryDefinition* flib_def = opts.flib_def;
if (flib_def == nullptr) {
flib_def = &g->flib_def();
}
// Set versions, function library and send/recv incarnation.
for (auto& it : *partitions) {
GraphDef* gdef = &it.second;
*gdef->mutable_versions() = g->versions();
// Prune unreachable functions from `flib_def` before adding them to `gdef`.
*gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();
// Traverse the graph to fill every send/recv op's incarnation
// information.
SetIncarnation(opts, gdef);
}
// Set the start times for recvs at the very end.
// By now each 'dup_recv' entry holds the minimum start time over all of the
// consumers that shared that recv.
if (opts.scheduling_for_recvs) {
for (auto& it : dup_recv) {
AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
if (it.second.real_recv != it.second.recv) {
AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
}
}
}
VLOG(1) << "Added send/recv: controls=" << num_control
<< ", data=" << num_data;
// At VLOG level 2, dump every partition to a file. The file name is built
// from the partition key and the address of its GraphDef.
if (VLOG_IS_ON(2)) {
for (auto& it : *partitions) {
GraphDef* gdef = &it.second;
DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
reinterpret_cast<uintptr_t>(gdef)),
*gdef);
}
}
return Status::OK();
}