in tensorflow/tensorflow/compiler/tf2xla/functionalize_while.cc [263:561]
// Rewrites the control-flow loop described by `frame` into a single "While"
// node in `graph`.
//
// The work happens in this order:
//   1. Split loop-varying Enter nodes that feed several successors, so that
//      every loop argument has its own Enter node.
//   2. Sort the arguments and, for each loop-varying one, locate its
//      Merge, NextIteration, Switch and (optional) Exit nodes.
//   3. Build the condition and body graphs, run FunctionalizeCond() on each,
//      convert them to FunctionDefs and add those to `library`.
//   4. Build the While node, wire the Enter inputs and Exit consumers to it.
//   5. Remove every node of the frame from `graph` and record the While node
//      in the parent frame's node set.
//
// Args:
//   lookup_library: optional (may be null). When non-null it is passed,
//     together with `library`, to AddMissingFunctionDef() for the generated
//     condition and body functions.
//   graph: the graph that is modified in place.
//   frame: the loop frame to functionalize. `frame->args` is rewritten and
//     `frame->nodes` is cleared on success.
//   library: receives the generated condition and body FunctionDefs.
//
// Returns Status::OK() on success. Returns InvalidArgument when the loop does
// not match the expected Enter/Merge/NextIteration/Switch/Exit pattern or has
// no LoopCond, Internal when a loop-varying Enter has zero or multiple
// successors after splitting, Unimplemented when a node other than Identity
// sits between a Switch and its Exit, or whatever error a called helper
// returns.
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
Graph* graph, WhileLoopFrame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< DumpGraphToFile("functionalize_before", *graph, library);
// Split loop-varying Enter nodes with multiple successors. If the same
// Tensor is fed as input to multiple loop arguments, we may end up with a
// shared Enter node. We clone Enter nodes with multiple successors to
// maintain the invariant of a unique Enter node per argument of the final
// loop.
std::vector<WhileLoopArg> args;
for (const WhileLoopArg& arg : frame->args) {
if (arg.is_loop_invariant) {
// Loop-invariant arguments are kept unchanged.
args.push_back(arg);
} else {
// Snapshot the out-edges into a vector: the loop below removes edges from
// arg.enter, so it must not iterate over the live edge set.
std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
arg.enter->out_edges().end());
for (int i = 0; i < edges.size(); ++i) {
// Control edges to the Sink node do not count as successors.
if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
continue;
}
// Any other control edge leaving a loop-varying Enter is rejected.
TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
WhileLoopArg new_arg;
new_arg.is_loop_invariant = false;
// The edge at index 0 keeps the original Enter node; every later edge
// gets its own clone.
// NOTE(review): if edge 0 is a skipped Sink control edge, the original
// Enter is never added to `args` and every data successor receives a
// clone. Looks harmless but possibly unintended -- confirm.
if (i == 0) {
new_arg.enter = arg.enter;
} else {
new_arg.enter = graph->CopyNode(arg.enter);
// Register the clone with the frame so it is removed with the other
// frame nodes at the end of this function.
frame->nodes.insert(new_arg.enter);
// Give the clone the same inputs as the original Enter. Data inputs are
// all attached to input slot 0 of the clone.
for (Edge const* e : arg.enter->in_edges()) {
graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
e->IsControlEdge() ? Graph::kControlSlot : 0);
}
// Redirect this successor from the original Enter to the clone. The
// destination is read before the edge is removed.
Node* dst = edges[i]->dst();
int dst_input = edges[i]->dst_input();
graph->RemoveEdge(edges[i]);
graph->AddEdge(new_arg.enter, 0, dst, dst_input);
}
args.push_back(new_arg);
}
}
}
frame->args = std::move(args);
// Order the arguments by their Enter nodes using NodeCmpByNameResourcesLast
// (defined elsewhere). The resulting position of each argument is the index
// used below for the While node's inputs and outputs.
std::sort(frame->args.begin(), frame->args.end(),
[](const WhileLoopArg& a, const WhileLoopArg& b) {
return NodeCmpByNameResourcesLast()(a.enter, b.enter);
});
if (frame->loop_cond == nullptr) {
return errors::InvalidArgument("Loop ", frame->name,
" has no LoopCond node");
}
// Find the set of Switch nodes that are successors of the LoopCond, i.e.
// Switch nodes whose input 1 is fed by a data edge from the LoopCond.
std::unordered_set<Node*> switches;
for (const Edge* edge : frame->loop_cond->out_edges()) {
if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
edge->dst_input() == 1) {
switches.insert(edge->dst());
}
}
// For each non-constant argument, looks for the following pattern of nodes:
//   Enter ----> Merge --------> Switch --> Exit
//                 ^                ^
//                 |                |
//           NextIteration      LoopCond
//                 ^                ^
//                 |                |
//                ...              ...
for (WhileLoopArg& arg : frame->args) {
if (!arg.is_loop_invariant) {
// Follow the edge from the Enter to Merge.
const Edge* enter_merge = nullptr;
for (const Edge* e : arg.enter->out_edges()) {
// Ignore control-edges to the sink node. These are allowed by the
// graph invariants, although probably they should have been stripped
// off earlier.
if (e->IsControlEdge() && e->dst()->IsSink()) {
continue;
}
// After the split above, exactly one successor is expected.
if (enter_merge != nullptr) {
return errors::Internal("Enter node for loop-varying argument ",
FormatNodeForError(*arg.enter),
" has multiple successors: ",
FormatNodeForError(*enter_merge->dst()),
" and ", FormatNodeForError(*e->dst()));
}
enter_merge = e;
}
if (enter_merge == nullptr) {
return errors::Internal("Enter node for loop-varying argument ",
FormatNodeForError(*arg.enter),
" has zero successors");
}
arg.merge = enter_merge->dst();
if (!IsMerge(arg.merge)) {
return errors::InvalidArgument(
"Successor of Enter node for loop-varying argument ",
FormatNodeForError(*arg.merge),
" is not a Merge node; got: ", arg.merge->type_string());
}
// Find the NextIteration from the merge. There should be two inputs to
// the Merge and the NextIteration should be the other input.
if (arg.merge->input_types().size() != 2) {
return errors::InvalidArgument(
"Unexpected number of inputs to Merge node for loop-varying "
"argument ",
FormatNodeForError(*arg.merge), "; expected 2, got ",
arg.merge->input_types().size());
}
// The Merge has exactly two inputs, so the input that is not fed by the
// Enter has index `1 - enter_merge->dst_input()`.
TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
&arg.next_iteration));
if (!IsNextIteration(arg.next_iteration)) {
return errors::InvalidArgument(
"Expected NextIteration node as input to Merge node; got node ",
FormatNodeForError(*arg.next_iteration), " with kind ",
arg.next_iteration->type_string());
}
// Find the Switch successor of the Merge. There should be exactly one
// Switch node that is a successor of both the Merge and the LoopCond.
// Only edges into input 0 of the Switch are considered.
for (const Edge* edge : arg.merge->out_edges()) {
if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
switches.find(edge->dst()) != switches.end()) {
if (arg.switch_node != nullptr) {
return errors::InvalidArgument("Duplicate Switch successors to ",
FormatNodeForError(*arg.merge));
}
arg.switch_node = edge->dst();
}
}
if (arg.switch_node == nullptr) {
return errors::InvalidArgument("Missing Switch successor to ",
FormatNodeForError(*arg.merge));
}
// Loop over the switch node's outputs to:
// - Collect the edges leaving output 0 of the switch; these are the
//   starting points of the search for the Exit successor below.
// - Set the sharding on all Identity outputs of the switch. These
//   identity nodes are values used by the loop body or condition.
//   The Identity node may have the wrong device so copy the device from
//   one of its outputs instead.
std::deque<const Edge*> possible_exit;
for (const Edge* edge : arg.switch_node->out_edges()) {
if (edge->src_output() == 0) {
possible_exit.push_back(edge);
}
if (IsIdentity(edge->dst())) {
TF_RETURN_IF_ERROR(
SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
}
}
// Breadth-first walk from output 0 of the switch, passing through Identity
// nodes only, until an Exit node is reached. At most one Exit may be found;
// arg.exit is left unset when none is found (handled when output edges are
// added below).
// TODO(b/67425339): Allow general graph between switch and exit.
while (!possible_exit.empty()) {
const Edge* edge = possible_exit.front();
possible_exit.pop_front();
if (IsExit(edge->dst())) {
if (arg.exit != nullptr) {
return errors::InvalidArgument(
"Duplicate Exit successors to ",
FormatNodeForError(*arg.switch_node));
}
arg.exit = edge->dst();
} else {
if (!IsIdentity(edge->dst())) {
return errors::Unimplemented("General graph between switch (",
FormatNodeForError(*arg.switch_node),
") and exit node of frame ",
frame->name, " not supported yet.");
}
// Keep following the successors of the Identity node.
for (const Edge* out : edge->dst()->out_edges()) {
possible_exit.push_back(out);
}
}
}
}
}
// Builds the condition and body functions. Notice that we call
// FunctionalizeCond() on cond_graph and body_graph because we might have
// unfunctionalized "if" in cond_graph and body_graph. Functionalize them
// before they are encapsulated in FunctionDef.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
FixupSourceAndSinkEdges(cond_graph.get());
TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
// BuildLoopBody also fills `arg_types`, which is used below for the "T"
// attribute and the input types of the While node.
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
FixupSourceAndSinkEdges(body_graph.get());
TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
VLOG(2) << "Frame " << frame->name << " condition: "
<< DumpGraphToFile("loop_condition", *cond_graph, library)
<< " body: " << DumpGraphToFile("loop_body", *body_graph);
// A function-local atomic counter supplies the numeric suffix shared by the
// condition and body function names generated by this call.
static std::atomic<int64> sequence_num(0LL);
int64 id = ++sequence_num;
NameAttrList cond_name;
cond_name.set_name(absl::StrCat("_functionalize_cond_", id));
NameAttrList body_name;
body_name.set_name(absl::StrCat("_functionalize_body_", id));
FunctionDef cond_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
FunctionDef body_fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
if (lookup_library) {
// Copy missing FunctionDefs from lookup_library to library to make library
// self-contained.
TF_RETURN_IF_ERROR(
AddMissingFunctionDef(cond_fdef, lookup_library, library));
TF_RETURN_IF_ERROR(
AddMissingFunctionDef(body_fdef, lookup_library, library));
}
// Builds a While operator. The node takes the name of the LoopCond node.
NodeDef while_def;
NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
// Carry the outside-compilation attribute over from the LoopCond node when
// it is present there.
string outside_compilation;
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {
builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
}
// Describe the While inputs in the NodeDef: argument i is fed by whatever
// feeds input 0 of its Enter node.
std::vector<NodeDefBuilder::NodeOut> inputs;
for (int i = 0; i < frame->args.size(); ++i) {
const WhileLoopArg& arg = frame->args[i];
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
// NOTE(review): the edge is requested for input slot 0, so the control-edge
// branch presumably never runs -- confirm.
if (in_edge->IsControlEdge()) {
builder.ControlInput(in_edge->src()->name());
} else {
inputs.push_back(NodeDefBuilder::NodeOut(
in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
}
}
builder.Input(inputs);
TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
// Copies edges to the Enter nodes and from the Exit nodes onto the While.
// Argument i uses input i and output i of the While node.
for (int i = 0; i < frame->args.size(); ++i) {
const WhileLoopArg& arg = frame->args[i];
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
graph->AddControlEdge(in_edge->src(), while_node);
} else {
graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
}
if (!arg.is_loop_invariant) {
// Add output edges if the output of the loop is consumed.
if (arg.exit != nullptr) {
// Snapshot the Exit node's out-edges: they are removed inside the loop.
std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
arg.exit->out_edges().end());
for (const Edge* edge : edges) {
// Read the destination before the edge is removed, then re-create the
// edge with the While node as its source.
Node* dst = edge->dst();
int dst_input = edge->dst_input();
graph->RemoveEdge(edge);
if (dst_input == Graph::kControlSlot) {
graph->AddControlEdge(while_node, dst);
} else {
graph->AddEdge(while_node, i, dst, dst_input);
}
}
}
}
}
// Remove the old nodes from the graph, and add the while node to the parent
// frame.
// NOTE(review): frame->parent is dereferenced without a null check; this
// assumes every frame passed here has a parent -- confirm against callers.
for (Node* node : frame->nodes) {
graph->RemoveNode(node);
}
frame->nodes.clear();
frame->parent->nodes.insert(while_node);
VLOG(2) << "Frame " << frame->name << " after: "
<< DumpGraphToFile("functionalize_after", *graph, library);
return Status::OK();
}