static Status TranslateFusedConv2DOp()

in ngraph_bridge/ngraph_builder.cc [1457:1650]


static Status TranslateFusedConv2DOp(const Node* op,
                                     const std::vector<const Tensor*>&,
                                     Builder::OpMap& ng_op_map) {
  int num_args;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "num_args", &num_args));
  NGRAPH_VLOG(3) << "num_args: " << num_args;

  std::vector<string> fused_ops;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "fused_ops", &fused_ops));

  std::string tf_data_format;
  TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "data_format", &tf_data_format));
  bool is_nhwc = (tf_data_format == "NHWC");

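  // Helper: emits the base convolution. TF tensors may be NHWC, while
  // nGraph's Convolution expects NCHW activations and OIHW filters, so the
  // lambda normalizes layouts before constructing the op.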
  auto CreateNgConv = [&](ng::Output<ng::Node>& ng_input,
                          ng::Output<ng::Node>& ng_filter,
                          ng::Output<ng::Node>& ng_conv) {
    std::vector<int32> tf_strides;
    std::vector<int32> tf_dilations;
    std::string tf_padding_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "strides", &tf_strides));
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "dilations", &tf_dilations));
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "padding", &tf_padding_type));

    if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
      return errors::InvalidArgument(
          "Conv2D data format is neither NHWC nor NCHW");
    }

    // TF kernel test checks:
    // strides in the batch and depth dimensions are not supported.
    if (tf_strides[0] != 1 || tf_strides[is_nhwc ? 3 : 1] != 1) {
      return errors::InvalidArgument(
          "Strides in batch and depth dimensions are not supported: ",
          op->type_string());
    }

    NGRAPH_VLOG(3) << ng::join(tf_strides);
    NGRAPH_VLOG(3) << ng::join(tf_dilations);
    NGRAPH_VLOG(3) << tf_padding_type;
    NGRAPH_VLOG(3) << tf_data_format;

    ng::Strides ng_strides(2);
    ng::Strides ng_dilations(2);
    ng::Shape ng_image_shape(2);
    ng::Shape ng_kernel_shape(2);

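    // Pull out the spatial (H, W) components of the strides, dilations, and
    // input shape, then transpose the input itself to NCHW when the TF
    // layout is NHWC.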
    NHWCtoHW(is_nhwc, tf_strides, ng_strides);
    NHWCtoHW(is_nhwc, ng_input.get_shape(), ng_image_shape);
    NHWCtoHW(is_nhwc, tf_dilations, ng_dilations);
    NHWCtoNCHW(op->name(), is_nhwc, ng_input);

    NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
    NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
    NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);

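    // TF filters are laid out HWIO; record the spatial kernel dims, then
    // transpose the filter to nGraph's OIHW layout.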
    auto& ng_filter_shape = ng_filter.get_shape();
    ng_kernel_shape[0] = ng_filter_shape[0];
    ng_kernel_shape[1] = ng_filter_shape[1];
    Transpose<3, 2, 0, 1>(ng_filter);
    Builder::SetTracingInfo(op->name(), ng_filter);

    NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);

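    // Derive explicit below/above padding from the TF padding scheme
    // ("SAME" or "VALID") and the image/kernel geometry.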
    ng::CoordinateDiff ng_padding_below;
    ng::CoordinateDiff ng_padding_above;
    Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
                         ng_strides, ng_dilations, ng_padding_below,
                         ng_padding_above);

    ng_conv = ConstructNgNode<opset::Convolution>(
        op->name() + "_FusedConv2D_Conv", ng_input, ng_filter, ng_strides,
        ng_padding_below, ng_padding_above, ng_dilations);

    return Status::OK();
  };

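  // Helper: adds the 1-D bias to the convolution output. The conv result is
  // NCHW at this point, so the bias must be reshaped before it can broadcast
  // along the channel axis.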
  auto CreateNgBiasAdd = [&](ng::Output<ng::Node>& ng_conv,
                             ng::Output<ng::Node>& ng_bias,
                             ng::Output<ng::Node>& ng_bias_add) {
    auto ng_conv_shape = ng_conv.get_shape();
    auto ng_bias_shape = ng_bias.get_shape();
    if (ng_bias_shape.size() != 1) {
      return errors::InvalidArgument(
          "Bias argument to BiasAdd does not have one dimension");
    }

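    // Reshape the bias from [C] to [1, C, 1, 1] (ones in every dimension
    // but the channel) so the Add broadcasts it over the channel dimension
    // of the NCHW convolution output.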
    std::vector<size_t> reshape_pattern_values(ng_conv_shape.size(), 1U);
    reshape_pattern_values[1] = ng_bias.get_shape().front();
    auto reshape_pattern = make_shared<opset::Constant>(
        ng::element::u64, ng::Shape{reshape_pattern_values.size()},
        reshape_pattern_values);
    auto ng_bias_reshaped = ConstructNgNode<opset::Reshape>(
        op->name(), ng_bias, reshape_pattern, false);

    ng_bias_add = ConstructNgNode<opset::Add>(
        op->name() + "_FusedConv2D_BiasAdd", ng_conv, ng_bias_reshaped);
    return Status::OK();
  };

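  // Validate the fused-op pattern against num_args: BiasAdd takes one extra
  // input (the bias), BiasAdd+Add takes two (bias and the added tensor), and
  // FusedBatchNorm takes four (scale, offset, mean, variance).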
  if (VecStrCmp(fused_ops, {"BiasAdd"}) ||
      VecStrCmp(fused_ops, {"BiasAdd", "Relu"}) ||
      VecStrCmp(fused_ops, {"BiasAdd", "Relu6"})) {
    if (num_args != 1) {
      return errors::InvalidArgument(
          "FusedConv2DBiasAdd has incompatible num_args");
    }
  } else if (VecStrCmp(fused_ops, {"BiasAdd", "Add", "Relu"}) ||
             VecStrCmp(fused_ops, {"BiasAdd", "Add"})) {
    if (num_args != 2) {
      return errors::InvalidArgument(
          "FusedConv2DBiasAddAdd has incompatible num_args");
    }
  } else if (VecStrCmp(fused_ops, {"FusedBatchNorm"}) ||
             VecStrCmp(fused_ops, {"FusedBatchNorm", "Relu"}) ||
             VecStrCmp(fused_ops, {"FusedBatchNorm", "Relu6"})) {
    if (num_args != 4) {
      return errors::InvalidArgument(
          "FusedConv2D with FusedBatchNorm has incompatible num_args");
    }
  } else {
    return errors::Unimplemented("Unsupported _FusedConv2D " +
                                 absl::StrJoin(fused_ops, ","));
  }

  // Conv2D
  ng::Output<ng::Node> ng_input, ng_filter, ng_conv;
  TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, ng_input));
  TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 1, ng_filter));
  TF_RETURN_IF_ERROR(CreateNgConv(ng_input, ng_filter, ng_conv));

  // BiasAdd or BatchNorm
  ng::Output<ng::Node> ng_fused_op_0, ng_bias, ng_scale, ng_offset, ng_mean,
      ng_variance;
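  // Input 2 onward holds the fused arguments: the bias for BiasAdd, or
  // scale/offset/mean/variance (inputs 2-5) for FusedBatchNorm.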
  if (fused_ops[0] == "BiasAdd") {
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 2, ng_bias));
    TF_RETURN_IF_ERROR(CreateNgBiasAdd(ng_conv, ng_bias, ng_fused_op_0));
  } else if (fused_ops[0] == "FusedBatchNorm") {
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 2, ng_scale));
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 3, ng_offset));
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 4, ng_mean));
    TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 5, ng_variance));
    float tf_epsilon;
    TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "epsilon", &tf_epsilon));
    ng_fused_op_0 = ConstructNgNode<opset::BatchNormInference>(
        op->name() + "_FusedConv2D_BatchNorm", ng_conv, ng_scale, ng_offset,
        ng_mean, ng_variance, tf_epsilon);
  } else {
    // Unreachable: fused_ops[0] was validated above.
    return errors::Unimplemented("Unsupported _FusedConv2D " +
                                 absl::StrJoin(fused_ops, ","));
  }

  ng::Output<ng::Node> ng_input_add, ng_fused_op_1;
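  // With a single fused op we are done; otherwise apply the second fused op
  // (Add consumes extra input 3; Relu/Relu6 are elementwise activations).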
  if (fused_ops.size() == 1) {
    NCHWtoNHWC(op->name(), is_nhwc, ng_fused_op_0);
    SaveNgOp(ng_op_map, op->name(), ng_fused_op_0);
    return Status::OK();
  } else {
    // Add or Relu or Relu6
    if (fused_ops[1] == "Add") {
      TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 3, ng_input_add));
      NHWCtoNCHW(op->name(), is_nhwc, ng_input_add);
      ng_fused_op_1 = ConstructNgNode<opset::Add>(
          op->name() + "_FusedConv2D_Add", ng_fused_op_0, ng_input_add);
    } else if (fused_ops[1] == "Relu") {
      ng_fused_op_1 = ConstructNgNode<opset::Relu>(
          op->name() + "_FusedConv2D_Relu", ng_fused_op_0);
    } else if (fused_ops[1] == "Relu6") {
      ng_fused_op_1 = ConstructNgNode<opset::Clamp>(
          op->name() + "_FusedConv2D_Relu6", ng_fused_op_0, 0, 6);
    } else {
      // Unreachable: fused_ops was validated above.
      return errors::Unimplemented("Unsupported _FusedConv2D " +
                                   absl::StrJoin(fused_ops, ","));
    }
  }

  ng::Output<ng::Node> ng_fused_op_2;
  if (fused_ops.size() == 2) {
    NCHWtoNHWC(op->name(), is_nhwc, ng_fused_op_1);
    SaveNgOp(ng_op_map, op->name(), ng_fused_op_1);
    return Status::OK();
  } else {
    // Third fused op: only {"BiasAdd", "Add", "Relu"} reaches this point.
    if (fused_ops[2] == "Relu") {
      ng_fused_op_2 = ConstructNgNode<opset::Relu>(
          op->name() + "_FusedConv2D_Relu", ng_fused_op_1);
      NCHWtoNHWC(op->name(), is_nhwc, ng_fused_op_2);
      SaveNgOp(ng_op_map, op->name(), ng_fused_op_2);
    } else {
      // Unreachable: fused_ops was validated above. Without this branch the
      // function would fall through and return OK without saving any output.
      return errors::Unimplemented("Unsupported _FusedConv2D " +
                                   absl::StrJoin(fused_ops, ","));
    }
  }
  return Status::OK();
}