in ngraph_bridge/ngraph_builder.cc [977:1068]
static Status TranslateConv2DBackpropInputOp(
    const Node* op, const std::vector<const Tensor*>& static_input_map,
    Builder::OpMap& ng_op_map) {
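  // Inputs to TF Conv2DBackpropInput: input_sizes (static), filter, and
  // out_backprop (the gradient with respect to the convolution output).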
  ng::Output<ng::Node> ng_filter, ng_out_backprop, ng_unused;
  TF_RETURN_IF_ERROR(
      GetInputNodes(ng_op_map, op, ng_unused, ng_filter, ng_out_backprop));

  // TODO: refactor me to be less redundant with other convolution ops
  std::vector<int32> tf_strides;
  std::vector<int32> tf_dilations;
  std::string tf_padding_type;
  std::string tf_data_format;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "strides", &tf_strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "dilations", &tf_dilations));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "padding", &tf_padding_type));
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "data_format", &tf_data_format));
  if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
    return errors::InvalidArgument(
        "Conv2DBackpropInput data format is neither NHWC nor NCHW: ",
        tf_data_format);
  }
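  // input_sizes (input 0) must be statically known, so it is read from the
  // static input map rather than as a graph input.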
  std::vector<int64> tf_input_sizes;
  TF_RETURN_IF_ERROR(
      GetStaticInputVector(op, 0, static_input_map, &tf_input_sizes));

  if (std::any_of(tf_input_sizes.begin(), tf_input_sizes.end(),
                  [](int64 size) { return size <= 0; })) {
    return errors::InvalidArgument(
        "Conv2DBackpropInput input sizes must be positive integers");
  }
  bool is_nhwc = (tf_data_format == "NHWC");

  NGRAPH_VLOG(3) << ng::join(tf_strides);
  NGRAPH_VLOG(3) << ng::join(tf_dilations);
  NGRAPH_VLOG(3) << tf_padding_type;
  NGRAPH_VLOG(3) << tf_data_format;
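  // nGraph convolution ops expect channels-first (NCHW) data, so the
  // strides, dilations, image shape, and incoming gradient are reordered
  // from the TF layout where necessary.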
  ng::Strides ng_strides(2);
  ng::Strides ng_dilations(2);
  ng::Shape ng_image_shape(2);
  ng::Shape ng_kernel_shape(2);
  ng::Shape ng_batch_shape(4);

  NHWCtoHW(is_nhwc, tf_strides, ng_strides);
  NHWCtoHW(is_nhwc, tf_dilations, ng_dilations);
  NHWCtoHW(is_nhwc, tf_input_sizes, ng_image_shape);
  NHWCtoNCHW(op->name(), is_nhwc, ng_out_backprop);
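  // Reorder the requested input shape (this op's output shape) into NCHW.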
  if (is_nhwc) {
    ng_batch_shape = {static_cast<size_t>(tf_input_sizes[0]),
                      static_cast<size_t>(tf_input_sizes[3]),
                      static_cast<size_t>(tf_input_sizes[1]),
                      static_cast<size_t>(tf_input_sizes[2])};
  } else {
    ng_batch_shape = {static_cast<size_t>(tf_input_sizes[0]),
                      static_cast<size_t>(tf_input_sizes[1]),
                      static_cast<size_t>(tf_input_sizes[2]),
                      static_cast<size_t>(tf_input_sizes[3])};
  }
NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
  auto& ng_filter_shape = ng_filter.get_shape();
  ng_kernel_shape[0] = ng_filter_shape[0];
  ng_kernel_shape[1] = ng_filter_shape[1];
  Transpose<3, 2, 0, 1>(ng_filter);
  Builder::SetTracingInfo(op->name(), ng_filter);
  NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
  ng::CoordinateDiff ng_padding_below;
  ng::CoordinateDiff ng_padding_above;
  Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
                       ng_strides, ng_dilations, ng_padding_below,
                       ng_padding_above);
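  // The spatial dimensions of the requested input shape are passed to
  // ConvolutionBackpropData as an explicit output-shape constant.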
  auto ng_output_shape = ConstructNgNode<opset::Constant>(
      op->name(), ng::element::i64, ng::Shape{ng_batch_shape.size() - 2},
      std::vector<size_t>(ng_batch_shape.begin() + 2, ng_batch_shape.end()));

  auto ng_data = ConstructNgNode<opset::ConvolutionBackpropData>(
      op->name(), ng_out_backprop, ng_filter, ng_output_shape, ng_strides,
      ng_padding_below, ng_padding_above, ng_dilations);
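  // Convert the result back to the TF layout before recording it.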
  NCHWtoNHWC(op->name(), is_nhwc, ng_data);
  SaveNgOp(ng_op_map, op->name(), ng_data);
  return Status::OK();
}