static Status TranslateDepthwiseConv2dNativeOp()

in ngraph_bridge/ngraph_builder.cc [1189:1263]


// Translates a TensorFlow DepthwiseConv2dNative node into an nGraph
// GroupConvolution.
//
// The two inputs (image, filter) are fetched from ng_op_map. The filter is
// treated as having the depthwise layout [H, W, I, M] and is rearranged into
// the GroupConvolution filter layout [I, M, 1, H, W]. NHWC inputs are
// transposed to NCHW for the convolution and the result is transposed back.
//
// Parameters:
//   op        - the TF node being translated; its "strides", "dilations",
//               "padding" and "data_format" attributes are read.
//   (unnamed) - static input tensors; unused by this translator.
//   ng_op_map - map from TF node names to nGraph outputs; the inputs are
//               looked up here and the result is saved under op->name().
// Returns:
//   Status::OK() on success; InvalidArgument for an unsupported data format,
//   strides/dilations that do not have 4 entries, non-unit batch/depth
//   strides, or an input/filter that is not rank 4.
static Status TranslateDepthwiseConv2dNativeOp(
    const Node* op, const std::vector<const Tensor*>&,
    Builder::OpMap& ng_op_map) {
  ng::Output<ng::Node> ng_input, ng_filter;
  TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, ng_input, ng_filter));

  std::vector<int32> tf_strides;
  std::vector<int32> tf_dilations;
  std::string tf_padding_type;
  std::string tf_data_format;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "strides", &tf_strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "dilations", &tf_dilations));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "padding", &tf_padding_type));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "data_format", &tf_data_format));

  if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
    return errors::InvalidArgument(
        "DepthwiseConv2D data format is neither NHWC nor NCHW");
  }

  bool is_nhwc = (tf_data_format == "NHWC");

  // NHWCtoHW below picks the spatial entries out of these per-dimension
  // attributes by index, so reject malformed lengths up front instead of
  // reading past the end of the vectors.
  if (tf_strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D strides must have 4 elements, got ",
        tf_strides.size());
  }
  if (tf_dilations.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D dilations must have 4 elements, got ",
        tf_dilations.size());
  }

  // Only the spatial strides are forwarded to GroupConvolution; a non-unit
  // batch or depth stride would otherwise be silently ignored and produce
  // results that differ from TensorFlow (whose kernel rejects these too).
  if (tf_strides[0] != 1 || tf_strides[is_nhwc ? 3 : 1] != 1) {
    return errors::InvalidArgument(
        "Strides in batch and depth dimensions is not supported: ",
        op->type_string());
  }

  // Both the image shape and the filter shape are indexed as rank-4 below.
  if (ng_input.get_shape().size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D input must be rank 4, got rank ",
        ng_input.get_shape().size());
  }
  const auto& ng_filter_shape = ng_filter.get_shape();
  if (ng_filter_shape.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D filter must be rank 4, got rank ",
        ng_filter_shape.size());
  }

  NGRAPH_VLOG(3) << ng::join(tf_strides);
  NGRAPH_VLOG(3) << ng::join(tf_dilations);
  NGRAPH_VLOG(3) << tf_padding_type;
  NGRAPH_VLOG(3) << tf_data_format;

  ng::Strides ng_strides(2);
  ng::Strides ng_dilations(2);
  ng::Shape ng_image_shape(2);
  ng::Shape ng_kernel_shape(2);

  // Extract spatial (H, W) values, then bring the input into NCHW layout.
  NHWCtoHW(is_nhwc, ng_input.get_shape(), ng_image_shape);
  NHWCtoHW(is_nhwc, tf_strides, ng_strides);
  NHWCtoHW(is_nhwc, tf_dilations, ng_dilations);
  NHWCtoNCHW(op->name(), is_nhwc, ng_input);

  NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
  NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
  NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);

  // The first two filter dimensions are the spatial kernel size (H, W).
  ng_kernel_shape[0] = ng_filter_shape[0];
  ng_kernel_shape[1] = ng_filter_shape[1];

  NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);

  ng::CoordinateDiff ng_padding_below;
  ng::CoordinateDiff ng_padding_above;
  Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
                       ng_strides, ng_dilations, ng_padding_below,
                       ng_padding_above);

  // H W I M -> H W I 1 M
  auto filter_shape = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::u64, ng::Shape{5},
      ngraph::Shape{ng_filter_shape[0], ng_filter_shape[1], ng_filter_shape[2],
                    1, ng_filter_shape[3]});
  auto reshaped_filter = ConstructNgNode<opset::Reshape>(op->name(), ng_filter,
                                                         filter_shape, false);

  // H W I 1 M -> I M 1 H W (GroupConvolution's [groups, out, in, kH, kW]).
  auto order = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::i64, ng::Shape{5}, vector<int64>{2, 4, 3, 0, 1});
  auto transposed_filter =
      ConstructNgNode<opset::Transpose>(op->name(), reshaped_filter, order);

  auto ng_conv = ConstructNgNode<opset::GroupConvolution>(
      op->name(), ng_input, transposed_filter, ng_strides, ng_padding_below,
      ng_padding_above, ng_dilations);

  // Restore the caller's layout before publishing the result.
  NCHWtoNHWC(op->name(), is_nhwc, ng_conv);
  SaveNgOp(ng_op_map, op->name(), ng_conv);
  return Status::OK();
}