in ngraph_bridge/ngraph_builder.cc [2838:3001]
// Translates a TensorFlow graph into an nGraph function.
//
// Nodes are visited in reverse post-order (a topological order), with
// NodeComparatorName as the ordering comparator. _Arg nodes become nGraph
// Parameters, _Retval nodes become Results, and every other node is handed to
// the translation handler registered for its op type in TRANSLATE_OP_MAP.
//
// Params:
//   inputs           - shape of each graph input, indexed by the _Arg "index"
//                      attribute.
//   static_input_map - forwarded unchanged to every op translation handler.
//   input_graph      - the TensorFlow graph to translate.
//   name             - name given to the resulting nGraph function.
//   ng_function      - [out] receives the constructed nGraph function.
//
// Returns:
//   Unimplemented   if the graph contains a control-flow op;
//   InvalidArgument for malformed _Arg/_Retval nodes (missing attributes,
//                   out-of-range or duplicate indices, wrong input count) or
//                   for an op type with no registered translation handler;
//   Internal        if a translation handler throws a std::exception;
//   otherwise any error propagated from the helpers called here, or OK.
//
// Optional graph passes, controlled by environment variables:
//   NGRAPH_TF_CONSTANT_FOLDING == "1" -> ngraph::pass::ConstantFolding
//   NGRAPH_TF_TRANSPOSE_SINKING != "0" -> pass::TransposeSinking
Status Builder::TranslateGraph(
    const std::vector<TensorShape>& inputs,
    const std::vector<const Tensor*>& static_input_map,
    const Graph* input_graph, const string name,
    shared_ptr<ng::Function>& ng_function) {
  //
  // We will visit ops in topological order.
  //
  // ought to be `const Node*`, but GetReversePostOrder doesn't use `const`
  vector<Node*> ordered;
  GetReversePostOrder(*input_graph, &ordered, NodeComparatorName());

  //
  // Split ops into params, retvals, and all others.
  //
  vector<const Node*> tf_params;
  vector<const Node*> tf_ret_vals;
  vector<const Node*> tf_ops;
  for (const auto n : ordered) {
    if (n->IsSink() || n->IsSource()) {
      continue;
    }
    if (n->IsControlFlow()) {
      return errors::Unimplemented(
          "Encountered a control flow op in the nGraph bridge: ",
          n->DebugString());
    }
    if (n->IsArg()) {
      tf_params.push_back(n);
    } else if (n->IsRetval()) {
      tf_ret_vals.push_back(n);
    } else {
      tf_ops.push_back(n);
    }
  }

  //
  // The op map holds a mapping from TensorFlow op names (strings) to
  // vector of generated nGraph Output<Node>.
  //
  Builder::OpMap ng_op_map;

  //
  // Populate the parameter list, and also put parameters into the op map.
  //
  ng::ParameterVector ng_parameter_list(tf_params.size());
  for (auto parm : tf_params) {
    DataType dtype;
    if (GetNodeAttr(parm->attrs(), "T", &dtype) != Status::OK()) {
      return errors::InvalidArgument("No data type defined for _Arg");
    }
    int index = -1;
    if (GetNodeAttr(parm->attrs(), "index", &index) != Status::OK()) {
      return errors::InvalidArgument("No index defined for _Arg");
    }
    // `index` subscripts both `inputs` and `ng_parameter_list`; reject values
    // a malformed graph could supply instead of indexing out of bounds.
    if (index < 0 || static_cast<size_t>(index) >= ng_parameter_list.size() ||
        static_cast<size_t>(index) >= inputs.size()) {
      return errors::InvalidArgument(
          "_Arg ", parm->name(), " has index ", index,
          " which is out of range (", ng_parameter_list.size(),
          " _Arg nodes, ", inputs.size(), " input shapes)");
    }
    // With every index in range, rejecting duplicates guarantees each slot
    // of ng_parameter_list is filled exactly once (no null parameters).
    if (ng_parameter_list[index] != nullptr) {
      return errors::InvalidArgument("Duplicate _Arg index ", index, ": ",
                                     parm->name());
    }
    ng::element::Type ng_et;
    TF_RETURN_IF_ERROR(tf_utils::TFDataTypeToNGraphElementType(dtype, &ng_et));
    ng::Shape ng_shape;
    TF_RETURN_IF_ERROR(
        tf_utils::TFTensorShapeToNGraphShape(inputs[index], &ng_shape));
    // "_prov_tag" is optional; an empty tag is used when it is absent.
    string prov_tag;
    GetNodeAttr(parm->attrs(), "_prov_tag", &prov_tag);
    auto ng_param =
        ConstructNgNode<opset::Parameter>(prov_tag, ng_et, ng_shape);
    SaveNgOp(ng_op_map, parm->name(), ng_param);
    ng_parameter_list[index] =
        ngraph::as_type_ptr<opset::Parameter>(ng_param.get_node_shared_ptr());
  }

  //
  // Now create the nGraph ops from TensorFlow ops.
  //
  for (auto op : tf_ops) {
    NGRAPH_VLOG(2) << "Constructing op " << op->name() << " which is "
                   << op->type_string();
    const auto handler_it = TRANSLATE_OP_MAP.find(op->type_string());
    if (handler_it == TRANSLATE_OP_MAP.end()) {
      // -----------------------------
      // Catch-all for unsupported ops
      // -----------------------------
      NGRAPH_VLOG(3) << "No translation handler registered for op: "
                     << op->name() << " (" << op->type_string() << ")";
      NGRAPH_VLOG(3) << op->def().DebugString();
      return errors::InvalidArgument(
          "No translation handler registered for op: ", op->name(), " (",
          op->type_string(), ")\n", op->def().DebugString());
    }
    const auto& op_fun = handler_it->second;
    // Handlers may throw (e.g. from nGraph node construction); convert any
    // std::exception into a Status so it does not escape the bridge.
    try {
      TF_RETURN_IF_ERROR(op_fun(op, static_input_map, ng_op_map));
    } catch (const std::exception& e) {
      return errors::Internal("Unhandled exception in op handler: ", op->name(),
                              " (", op->type_string(), ")\n",
                              op->def().DebugString(), "\n", "what(): ",
                              e.what());
    }
  }

  //
  // Populate the result list.
  //
  ng::ResultVector ng_result_list(tf_ret_vals.size());
  for (auto n : tf_ret_vals) {
    // Make sure that this _Retval only has one input node.
    if (n->num_inputs() != 1) {
      return errors::InvalidArgument("_Retval has ", n->num_inputs(),
                                     " inputs, should have 1");
    }
    int index = -1;
    if (GetNodeAttr(n->attrs(), "index", &index) != Status::OK()) {
      return errors::InvalidArgument("No index defined for _Retval");
    }
    if (index < 0 || static_cast<size_t>(index) >= ng_result_list.size()) {
      return errors::InvalidArgument(
          "_Retval ", n->name(), " has index ", index,
          " which is out of range for ", ng_result_list.size(),
          " _Retval nodes");
    }
    if (ng_result_list[index] != nullptr) {
      return errors::InvalidArgument("Duplicate _Retval index ", index, ": ",
                                     n->name());
    }
    ng::Output<ng::Node> result;
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, n, 0, result));
    auto ng_result = ConstructNgNode<opset::Result>(n->name(), result);
    ng_result_list[index] =
        ngraph::as_type_ptr<opset::Result>(ng_result.get_node_shared_ptr());
  }

  //
  // Create the nGraph function.
  //
  ng_function =
      make_shared<ng::Function>(ng_result_list, ng_parameter_list, name);

  //
  // Apply additional passes on the nGraph function here.
  //
  {
    ngraph::pass::Manager passes;
    if (utils::GetEnv("NGRAPH_TF_CONSTANT_FOLDING") == "1") {
      passes.register_pass<ngraph::pass::ConstantFolding>();
    }
    if (utils::GetEnv("NGRAPH_TF_TRANSPOSE_SINKING") != "0") {
      passes.register_pass<pass::TransposeSinking>();
    }
    passes.run_passes(ng_function);
  }
  NGRAPH_VLOG(5) << "Done with passes";

  //
  // Request row-major layout on results.
  //
  for (const auto& result : ng_function->get_results()) {
    result->set_needs_default_layout(true);
  }
  NGRAPH_VLOG(5) << "Done with translations";
  return Status::OK();
}