in ngraph_bridge/pass/transpose_sinking.cc [434:496]
// Runs the transpose-sinking pass over `f`.
//
// STEP 1 walks the ops in the order returned by get_ordered_ops(). Based on
// the op type, it calls the matching sink_* helper for:
//   - Transpose,
//   - unary/binary elementwise arithmetic,
//   - Pad,
//   - Concat.
// It calls materialize_shapes() for every other op. The helpers collect the
// transposes to remove in `transposes_to_delete`.
// STEP 2 removes each collected transpose via delete_transpose().
// STEP 3 re-runs revalidate_and_infer_types() on every op. It then checks
// that each Result matches its argument's shape and element type, and still
// has the output shape it had before the pass.
//
// When utils::DumpAllGraphs() is enabled, the graph is dumped before and after
// the pass. Always returns true.
bool TransposeSinking::run_on_function(shared_ptr<ngraph::Function> f) {
  TransposeMap reorders;
  set<shared_ptr<ngraph::Node>> transposes_to_delete;
  // Output shape of every Result node before the pass, keyed by node name.
  unordered_map<std::string, ngraph::Shape> orig_result_out_shape;
  if (utils::DumpAllGraphs()) {
    utils::DumpNGGraph(f, f->get_friendly_name() + "_before_TS");
  }

  // STEP 1 : Sink or Swim transposes away for op clusters
  // Bind by const reference: copying a shared_ptr costs an atomic inc/dec.
  for (const auto& n : f->get_ordered_ops()) {
    NGRAPH_VLOG(4) << "Processing " << n->get_name();
    // collect output shape of all Result nodes for a sanity check
    if (ngraph::op::is_output(n)) {
      orig_result_out_shape[n->get_name()] = n->get_output_shape(0);
    }
    if (auto transpose = ngraph::as_type_ptr<opset::Transpose>(n)) {
      sink_transpose(transpose, reorders, transposes_to_delete);
    } else if (ngraph::op::is_unary_elementwise_arithmetic(n)) {
      sink_unary(n, reorders, transposes_to_delete);
    } else if (ngraph::op::is_binary_elementwise_arithmetic(n)) {
      sink_binary(n, reorders, transposes_to_delete);
    } else if (auto pad = ngraph::as_type_ptr<opset::Pad>(n)) {
      sink_pad(pad, reorders, transposes_to_delete);
    } else if (auto concat = ngraph::as_type_ptr<opset::Concat>(n)) {
      sink_concat(concat, reorders, transposes_to_delete);
    } else {
      materialize_shapes(n, reorders, transposes_to_delete);
    }
  }

  // STEP 2: purge all the transposes we either sunk or swam.
  NGRAPH_VLOG(4) << "Purging transposes ";
  for (const auto& r : transposes_to_delete) {
    delete_transpose(r);
  }

  // STEP 3: fix wrong shape info wholesale
  NGRAPH_VLOG(4) << "Fixing wrong shape info for the whole graph";
  for (const auto& n : f->get_ordered_ops()) {
    n->revalidate_and_infer_types();
  }
  const ngraph::ResultVector& results = f->get_results();
  for (const auto& r : results) {
    // make sure shapes are always materialized before results
    NGRAPH_CHECK(
        r->get_shape() == r->get_input_shape(0) &&
            r->get_element_type() == r->input_value(0).get_element_type(),
        " op::Result = ", *r, ", Arg = ", r->input_value(0).get_node());
    // make sure that after TransposeSinking pass the output_shape for Result
    // does not change from the expected output_shape before the pass.
    // Use find() rather than operator[]: operator[] would silently insert an
    // empty Shape for a Result not seen in STEP 1. That would let an unknown
    // scalar Result pass the check, or report a misleading expected shape.
    const auto orig_it = orig_result_out_shape.find(r->get_name());
    NGRAPH_CHECK(orig_it != orig_result_out_shape.end(), " op::Result = ", *r,
                 " was not present before TransposeSinking");
    NGRAPH_CHECK(r->get_output_shape(0) == orig_it->second, " op::Result = ",
                 *r, " expected output shape = ", orig_it->second);
  }
  if (utils::DumpAllGraphs()) {
    utils::DumpNGGraph(f, f->get_friendly_name() + "_after_TS");
  }
  return true;
}