in ngraph_bridge/ngraph_builder.cc [1189:1263]
// Translates a TensorFlow DepthwiseConv2dNative op into an nGraph
// opset::GroupConvolution.
//
// The TF filter arrives as [H, W, I, M] (see the reshape comments below). It is
// reshaped to [H, W, I, 1, M] and transposed to [I, M, 1, H, W] so that each
// input channel becomes its own group. This is the grouped-filter layout
// GroupConvolution consumes; see the opset spec.
//
// Inputs are converted to NCHW for the convolution and the result is converted
// back when the op's data_format is NHWC.
//
// Returns InvalidArgument when:
//   - data_format is neither NHWC nor NCHW,
//   - the strides or dilations attribute does not have exactly 4 entries,
//   - the input or the filter is not rank 4.
// These are checked before any node is built, because the code below indexes
// into all four of them by fixed position.
static Status TranslateDepthwiseConv2dNativeOp(
    const Node* op, const std::vector<const Tensor*>&,
    Builder::OpMap& ng_op_map) {
  ng::Output<ng::Node> ng_input, ng_filter;
  TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, ng_input, ng_filter));

  std::vector<int32> tf_strides;
  std::vector<int32> tf_dilations;
  std::string tf_padding_type;
  std::string tf_data_format;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "strides", &tf_strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "dilations", &tf_dilations));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "padding", &tf_padding_type));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "data_format", &tf_data_format));

  if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
    return errors::InvalidArgument(
        "DepthwiseConv2D data format is neither NHWC nor NCHW");
  }

  // NHWCtoHW is expected to pull H and W out of a 4-entry (NHWC / NCHW)
  // vector by position; reject anything else instead of reading past the end.
  if (tf_strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D strides must have 4 elements, got ",
        tf_strides.size());
  }
  if (tf_dilations.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D dilations must have 4 elements, got ",
        tf_dilations.size());
  }
  if (ng_input.get_shape().size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D input must be rank 4, got rank ",
        ng_input.get_shape().size());
  }

  // The filter is indexed as [H, W, I, M] below (elements 0..3).
  // ng_filter is never reassigned, so this reference stays valid.
  auto& ng_filter_shape = ng_filter.get_shape();
  if (ng_filter_shape.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D filter must be rank 4, got rank ",
        ng_filter_shape.size());
  }

  bool is_nhwc = (tf_data_format == "NHWC");

  NGRAPH_VLOG(3) << ng::join(tf_strides);
  NGRAPH_VLOG(3) << ng::join(tf_dilations);
  NGRAPH_VLOG(3) << tf_padding_type;
  NGRAPH_VLOG(3) << tf_data_format;

  ng::Strides ng_strides(2);
  ng::Strides ng_dilations(2);
  ng::Shape ng_image_shape(2);
  ng::Shape ng_kernel_shape(2);

  // The image dims are read while ng_input still has the op's original
  // layout (which is what is_nhwc describes). NHWCtoNCHW may rewrite ng_input
  // afterwards — presumably into NCHW for the convolution.
  NHWCtoHW(is_nhwc, ng_input.get_shape(), ng_image_shape);
  NHWCtoHW(is_nhwc, tf_strides, ng_strides);
  NHWCtoHW(is_nhwc, tf_dilations, ng_dilations);
  NHWCtoNCHW(op->name(), is_nhwc, ng_input);

  NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
  NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
  NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);

  // Spatial kernel extent comes from the filter's leading H, W dims.
  ng_kernel_shape[0] = ng_filter_shape[0];
  ng_kernel_shape[1] = ng_filter_shape[1];
  NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);

  ng::CoordinateDiff ng_padding_below;
  ng::CoordinateDiff ng_padding_above;
  Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
                       ng_strides, ng_dilations, ng_padding_below,
                       ng_padding_above);

  // H W I M -> H W I 1 M
  auto reshape_pattern = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::u64, ng::Shape{5},
      ngraph::Shape{ng_filter_shape[0], ng_filter_shape[1], ng_filter_shape[2],
                    1, ng_filter_shape[3]});
  auto reshaped_filter = ConstructNgNode<opset::Reshape>(
      op->name(), ng_filter, reshape_pattern, false);

  // H W I 1 M -> I M 1 H W
  auto order = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::i64, ng::Shape{5},
      std::vector<int64>{2, 4, 3, 0, 1});
  auto transposed_filter =
      ConstructNgNode<opset::Transpose>(op->name(), reshaped_filter, order);

  auto ng_conv = ConstructNgNode<opset::GroupConvolution>(
      op->name(), ng_input, transposed_filter, ng_strides, ng_padding_below,
      ng_padding_above, ng_dilations);

  NCHWtoNHWC(op->name(), is_nhwc, ng_conv);
  SaveNgOp(ng_op_map, op->name(), ng_conv);
  return Status::OK();
}