bool TransposeSinking::run_on_function()

in ngraph_bridge/pass/transpose_sinking.cc [434:496]


// Runs the transpose-sinking rewrite over `f`, modifying the graph in place.
//
// The pass proceeds in three steps:
//   1. Visit every op returned by f->get_ordered_ops() and dispatch on its
//      type: Transpose, unary/binary elementwise arithmetic, Pad and Concat
//      each get a dedicated sink_* handler; every other op is handed to
//      materialize_shapes(). Handlers record Transposes that can be removed
//      in `transposes_to_delete`.
//   2. Delete every Transpose collected in step 1.
//   3. Re-run revalidate_and_infer_types() on every op, because the rewrites
//      above leave stale shape information behind.
// Afterwards every Result is sanity-checked: its shape and element type must
// equal those of its input, and its output shape must be identical to the
// shape it had before the pass.
//
// When utils::DumpAllGraphs() is enabled the graph is dumped before and after
// the pass, with "_before_TS" / "_after_TS" appended to its friendly name.
//
// @param f  the function to rewrite.
// @return   always true. NOTE(review): presumably signals to the pass manager
//           that the function may have been modified — confirm against the
//           pass API.
// Fails an NGRAPH_CHECK if any Result invariant above is violated.
bool TransposeSinking::run_on_function(shared_ptr<ngraph::Function> f) {
  TransposeMap reorders;
  set<shared_ptr<ngraph::Node>> transposes_to_delete;
  // Result-node name -> output shape observed before the pass, used by the
  // post-pass sanity check.
  unordered_map<std::string, ngraph::Shape> orig_result_out_shape;

  if (utils::DumpAllGraphs()) {
    utils::DumpNGGraph(f, f->get_friendly_name() + "_before_TS");
  }

  // STEP 1 : Sink or Swim transposes away for op clusters
  // The loops below only read the node pointers, so bind by const reference
  // instead of copying each shared_ptr (an atomic refcount inc/dec).
  for (const auto& n : f->get_ordered_ops()) {
    NGRAPH_VLOG(4) << "Processing " << n->get_name();
    // collect output shape of all Result nodes for a sanity check
    if (ngraph::op::is_output(n)) {
      orig_result_out_shape[n->get_name()] = n->get_output_shape(0);
    }
    if (auto transpose = ngraph::as_type_ptr<opset::Transpose>(n)) {
      sink_transpose(transpose, reorders, transposes_to_delete);
    } else if (ngraph::op::is_unary_elementwise_arithmetic(n)) {
      sink_unary(n, reorders, transposes_to_delete);
    } else if (ngraph::op::is_binary_elementwise_arithmetic(n)) {
      sink_binary(n, reorders, transposes_to_delete);
    } else if (auto pad = ngraph::as_type_ptr<opset::Pad>(n)) {
      sink_pad(pad, reorders, transposes_to_delete);
    } else if (auto concat = ngraph::as_type_ptr<opset::Concat>(n)) {
      sink_concat(concat, reorders, transposes_to_delete);
    } else {
      materialize_shapes(n, reorders, transposes_to_delete);
    }
  }

  // STEP 2: purge all the transposes we either sunk or swam.
  NGRAPH_VLOG(4) << "Purging transposes ";
  for (const auto& r : transposes_to_delete) {
    delete_transpose(r);
  }

  // STEP 3: fix wrong shape info wholesale
  NGRAPH_VLOG(4) << "Fixing wrong shape info for the whole graph";
  for (const auto& n : f->get_ordered_ops()) {
    n->revalidate_and_infer_types();
  }

  const ngraph::ResultVector& results = f->get_results();
  for (const auto& r : results) {
    // make sure shapes are always materialized before results
    NGRAPH_CHECK(
        r->get_shape() == r->get_input_shape(0) &&
            r->get_element_type() == r->input_value(0).get_element_type(),
        " op::Result = ", *r, ", Arg = ", r->input_value(0).get_node());

    // Look the Result up once with find(): operator[] would silently insert an
    // empty Shape for an unrecorded Result and produce a misleading
    // "expected output shape = {}" diagnostic (or even pass for scalars).
    const auto orig_it = orig_result_out_shape.find(r->get_name());
    NGRAPH_CHECK(orig_it != orig_result_out_shape.end(), " op::Result = ", *r,
                 " was not recorded before the TransposeSinking pass");

    // make sure that after TransposeSinking pass the output_shape for Result
    // does not change from the expected output_shape before the pass
    NGRAPH_CHECK(r->get_output_shape(0) == orig_it->second,
                 " op::Result = ", *r, " expected output shape = ",
                 orig_it->second);
  }

  if (utils::DumpAllGraphs()) {
    utils::DumpNGGraph(f, f->get_friendly_name() + "_after_TS");
  }
  return true;
}