in ngraph_bridge/ngraph_builder.cc [1457:1650]
// Translates a TensorFlow `_FusedConv2D` node into nGraph ops.
//
// `_FusedConv2D` is a Conv2D followed by the chain of post-ops listed in the
// "fused_ops" attr. The chains accepted here, and the "num_args" value each
// requires, are:
//   {BiasAdd}, {BiasAdd, Relu}, {BiasAdd, Relu6}           -> num_args == 1
//   {BiasAdd, Add}, {BiasAdd, Add, Relu}                    -> num_args == 2
//   {FusedBatchNorm}, {FusedBatchNorm, Relu},
//   {FusedBatchNorm, Relu6}                                 -> num_args == 4
// Any other chain yields errors::Unimplemented; a num_args mismatch yields
// errors::InvalidArgument.
//
// Inputs fetched from ng_op_map: 0 = conv input, 1 = filter; for BiasAdd,
// 2 = bias and (for the Add variants) 3 = addend; for FusedBatchNorm,
// 2..5 = scale, offset, mean, variance.
//
// The computation is built in NCHW: an NHWC input (and NHWC addend) is
// converted with NHWCtoNCHW, and the final result is converted back with
// NCHWtoNHWC before being stored under op->name() via SaveNgOp.
static Status TranslateFusedConv2DOp(const Node* op,
                                     const std::vector<const Tensor*>&,
                                     Builder::OpMap& ng_op_map) {
  int num_args;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "num_args", &num_args));
  NGRAPH_VLOG(3) << "num_args: " << num_args;

  std::vector<string> fused_ops;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "fused_ops", &fused_ops));

  std::string tf_data_format;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "data_format", &tf_data_format));
  bool is_nhwc = (tf_data_format == "NHWC");

  // Builds the convolution. Side effects on the caller's outputs: ng_input is
  // converted to NCHW, ng_filter is transposed with <3, 2, 0, 1> (the filter is
  // assumed to be TF's HWIO layout, since the kernel H/W are read from its dims
  // 0 and 1), and ng_conv receives the NCHW convolution result.
  auto CreateNgConv = [&](ng::Output<ng::Node>& ng_input,
                          ng::Output<ng::Node>& ng_filter,
                          ng::Output<ng::Node>& ng_conv) {
    std::vector<int32> tf_strides;
    std::vector<int32> tf_dilations;
    std::string tf_padding_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "strides", &tf_strides));
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "dilations", &tf_dilations));
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "padding", &tf_padding_type));
    if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
      return errors::InvalidArgument(
          "Conv2D data format is neither NHWC nor NCHW");
    }

    // strides/dilations are indexed per 4-D dimension below (and by NHWCtoHW);
    // reject malformed lists instead of reading out of bounds.
    if (tf_strides.size() != 4) {
      return errors::InvalidArgument("Conv2D strides must have 4 elements, got ",
                                     tf_strides.size(), ": ",
                                     op->type_string());
    }
    if (tf_dilations.size() != 4) {
      return errors::InvalidArgument(
          "Conv2D dilations must have 4 elements, got ", tf_dilations.size(),
          ": ", op->type_string());
    }
    // The shape indexing and the 4-element transpose below require rank 4.
    if (ng_input.get_shape().size() != 4 || ng_filter.get_shape().size() != 4) {
      return errors::InvalidArgument(
          "_FusedConv2D expects rank-4 input and filter: ", op->name());
    }

    const size_t batch_dim = 0;
    const size_t depth_dim = is_nhwc ? 3 : 1;
    // TF Kernel Test Checks
    // Strides in the batch and depth dimension is not supported
    if (tf_strides[batch_dim] != 1 || tf_strides[depth_dim] != 1) {
      return errors::InvalidArgument(
          "Strides in batch and depth dimensions is not supported: ",
          op->type_string());
    }
    // Only the spatial (H, W) dilations are forwarded to nGraph, so non-unit
    // batch/depth dilations would otherwise be silently ignored.
    if (tf_dilations[batch_dim] != 1 || tf_dilations[depth_dim] != 1) {
      return errors::InvalidArgument(
          "Dilations in batch and depth dimensions is not supported: ",
          op->type_string());
    }

    NGRAPH_VLOG(3) << ng::join(tf_strides);
    NGRAPH_VLOG(3) << ng::join(tf_dilations);
    NGRAPH_VLOG(3) << tf_padding_type;
    NGRAPH_VLOG(3) << tf_data_format;

    ng::Strides ng_strides(2);
    ng::Strides ng_dilations(2);
    ng::Shape ng_image_shape(2);
    ng::Shape ng_kernel_shape(2);
    NHWCtoHW(is_nhwc, tf_strides, ng_strides);
    // Read the spatial image shape before ng_input is converted to NCHW.
    NHWCtoHW(is_nhwc, ng_input.get_shape(), ng_image_shape);
    NHWCtoHW(is_nhwc, tf_dilations, ng_dilations);
    NHWCtoNCHW(op->name(), is_nhwc, ng_input);

    NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
    NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
    NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);

    // Capture the kernel H/W before the filter is transposed.
    auto& ng_filter_shape = ng_filter.get_shape();
    ng_kernel_shape[0] = ng_filter_shape[0];
    ng_kernel_shape[1] = ng_filter_shape[1];
    Transpose<3, 2, 0, 1>(ng_filter);
    Builder::SetTracingInfo(op->name(), ng_filter);

    NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);

    ng::CoordinateDiff ng_padding_below;
    ng::CoordinateDiff ng_padding_above;
    Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
                         ng_strides, ng_dilations, ng_padding_below,
                         ng_padding_above);

    ng_conv = ConstructNgNode<opset::Convolution>(
        op->name() + "_FusedConv2D_Conv", ng_input, ng_filter, ng_strides,
        ng_padding_below, ng_padding_above, ng_dilations);
    return Status::OK();
  };

  // Adds a rank-1 bias to the NCHW conv output. The bias is reshaped to
  // [1, C, 1, ..., 1] so that it lines up with axis 1 when added.
  auto CreateNgBiasAdd = [&](ng::Output<ng::Node>& ng_conv,
                             ng::Output<ng::Node>& ng_bias,
                             ng::Output<ng::Node>& ng_bias_add) {
    auto ng_conv_shape = ng_conv.get_shape();
    auto ng_bias_shape = ng_bias.get_shape();
    if (ng_bias_shape.size() != 1) {
      return errors::InvalidArgument(
          "Bias argument to BiasAdd does not have one dimension");
    }

    std::vector<size_t> reshape_pattern_values(ng_conv_shape.size(), 1U);
    reshape_pattern_values[1] = ng_bias_shape.front();
    auto reshape_pattern = make_shared<opset::Constant>(
        ng::element::u64, ng::Shape{reshape_pattern_values.size()},
        reshape_pattern_values);
    auto ng_bias_reshaped = ConstructNgNode<opset::Reshape>(
        op->name(), ng_bias, reshape_pattern, false);

    ng_bias_add = ConstructNgNode<opset::Add>(
        op->name() + "_FusedConv2D_BiasAdd", ng_conv, ng_bias_reshaped);
    return Status::OK();
  };

  // Validate the fused-op chain and its matching argument count up front.
  if (VecStrCmp(fused_ops, {"BiasAdd"}) ||
      VecStrCmp(fused_ops, {"BiasAdd", "Relu"}) ||
      VecStrCmp(fused_ops, {"BiasAdd", "Relu6"})) {
    if (num_args != 1) {
      return errors::InvalidArgument(
          "FusedConv2DBiasAdd has incompatible num_args");
    }
  } else if (VecStrCmp(fused_ops, {"BiasAdd", "Add", "Relu"}) ||
             VecStrCmp(fused_ops, {"BiasAdd", "Add"})) {
    if (num_args != 2) {
      return errors::InvalidArgument(
          "FusedConv2DBiasAddAdd has incompatible num_args");
    }
  } else if (VecStrCmp(fused_ops, {"FusedBatchNorm"}) ||
             VecStrCmp(fused_ops, {"FusedBatchNorm", "Relu"}) ||
             VecStrCmp(fused_ops, {"FusedBatchNorm", "Relu6"})) {
    if (num_args != 4) {
      return errors::InvalidArgument(
          "FusedConv2D with FusedBatchNorm has incompatible num_args");
    }
  } else {
    return errors::Unimplemented("Unsupported _FusedConv2D " +
                                 absl::StrJoin(fused_ops, ","));
  }

  // Conv2D
  ng::Output<ng::Node> ng_input, ng_filter, ng_conv;
  TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, ng_input));
  TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 1, ng_filter));
  TF_RETURN_IF_ERROR(CreateNgConv(ng_input, ng_filter, ng_conv));

  // First fused op: BiasAdd or FusedBatchNorm.
  ng::Output<ng::Node> ng_fused_op_0, ng_bias, ng_scale, ng_offset, ng_mean,
      ng_variance;
  if (fused_ops[0] == "BiasAdd") {
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 2, ng_bias));
    TF_RETURN_IF_ERROR(CreateNgBiasAdd(ng_conv, ng_bias, ng_fused_op_0));
  } else if (fused_ops[0] == "FusedBatchNorm") {
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 2, ng_scale));
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 3, ng_offset));
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 4, ng_mean));
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 5, ng_variance));
    float tf_epsilon;
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "epsilon", &tf_epsilon));
    ng_fused_op_0 = ConstructNgNode<opset::BatchNormInference>(
        op->name() + "_FusedConv2D_BatchNorm", ng_conv, ng_scale, ng_offset,
        ng_mean, ng_variance, tf_epsilon);
  } else {
    // Unreachable: rejected by the fused_ops validation above.
    return errors::Unimplemented("Unsupported _FusedConv2D " +
                                 absl::StrJoin(fused_ops, ","));
  }

  if (fused_ops.size() == 1) {
    NCHWtoNHWC(op->name(), is_nhwc, ng_fused_op_0);
    SaveNgOp(ng_op_map, op->name(), ng_fused_op_0);
    return Status::OK();
  }

  // Second fused op: Add, Relu or Relu6.
  ng::Output<ng::Node> ng_input_add, ng_fused_op_1;
  if (fused_ops[1] == "Add") {
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 3, ng_input_add));
    // The addend arrives in the TF data format; match the NCHW conv result.
    NHWCtoNCHW(op->name(), is_nhwc, ng_input_add);
    ng_fused_op_1 = ConstructNgNode<opset::Add>(
        op->name() + "_FusedConv2D_Add", ng_fused_op_0, ng_input_add);
  } else if (fused_ops[1] == "Relu") {
    ng_fused_op_1 = ConstructNgNode<opset::Relu>(
        op->name() + "_FusedConv2D_Relu", ng_fused_op_0);
  } else if (fused_ops[1] == "Relu6") {
    ng_fused_op_1 = ConstructNgNode<opset::Clamp>(
        op->name() + "_FusedConv2D_Relu6", ng_fused_op_0, 0, 6);
  } else {
    // Unreachable: rejected by the fused_ops validation above.
    return errors::Unimplemented("Unsupported _FusedConv2D " +
                                 absl::StrJoin(fused_ops, ","));
  }

  if (fused_ops.size() == 2) {
    NCHWtoNHWC(op->name(), is_nhwc, ng_fused_op_1);
    SaveNgOp(ng_op_map, op->name(), ng_fused_op_1);
    return Status::OK();
  }

  // Third fused op: only Relu (from {BiasAdd, Add, Relu}) is supported. Fail
  // loudly instead of returning OK without having saved any output.
  if (fused_ops[2] != "Relu") {
    return errors::Unimplemented("Unsupported _FusedConv2D " +
                                 absl::StrJoin(fused_ops, ","));
  }
  ng::Output<ng::Node> ng_fused_op_2 = ConstructNgNode<opset::Relu>(
      op->name() + "_FusedConv2D_Relu", ng_fused_op_1);
  NCHWtoNHWC(op->name(), is_nhwc, ng_fused_op_2);
  SaveNgOp(ng_op_map, op->name(), ng_fused_op_2);
  return Status::OK();
}