void SgMKLDNNConvOperator::Forward()

in src/operator/subgraph/mkldnn/mkldnn_conv.cc [209:393]


void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
                                   const std::vector<NDArray> &inputs,
                                   const std::vector<OpReqType> &req,
                                   const std::vector<NDArray> &outputs) {
  auto &full_conv_param = param_.full_conv_param;
  auto &mkldnn_param = param_.full_conv_param.mkldnn_param;
  auto &conv_param = param_.full_conv_param.conv_param;
  auto bn_param = param_.bn_param.get();
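  // Expected input count: data and weight, plus the optional bias, the four BN
  // arrays (gamma, beta, moving mean, moving var), the sum operand, and the
  // quantization min/max scalars for data (and for the sum operand when both
  // sum fusion and quantization are enabled).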
  size_t input_size =
      2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) +
      (mkldnn_param.with_sum ? 1 : 0) +
      (mkldnn_param.quantized ? 2 + (full_conv_param.mkldnn_param.with_sum ? 2 : 0) : 0);
  CHECK_EQ(inputs.size(), input_size);
  size_t idx = 0;

  auto in_data = idx++;
  auto in_weight = idx++;
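  // Optional inputs; slots for inputs that are not present default to index 0.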
  auto in_bias = conv_param.no_bias ? 0 : (idx++);
  auto in_gamma = mkldnn_param.with_bn ? (idx++) : 0;
  auto in_beta = mkldnn_param.with_bn ? (idx++) : 0;
  auto in_mean = mkldnn_param.with_bn ? (idx++) : 0;
  auto in_var = mkldnn_param.with_bn ? (idx++) : 0;
  auto in_sum = mkldnn_param.with_sum ? (idx++) : 0;
  float data_min =
      mkldnn_param.quantized ? inputs[idx++].data().dptr<float>()[0] : 0.0;
  float data_max =
      mkldnn_param.quantized ? inputs[idx++].data().dptr<float>()[0] : 0.0;
  float sum_min = (mkldnn_param.with_sum && mkldnn_param.quantized)
                      ? inputs[idx++].data().dptr<float>()[0]
                      : 0.0;
  float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized)
                      ? inputs[idx++].data().dptr<float>()[0]
                      : 0.0;
  float *out_min_ptr =
      mkldnn_param.quantized ? outputs[kMin].data().dptr<float>() : nullptr;
  float *out_max_ptr =
      mkldnn_param.quantized ? outputs[kMax].data().dptr<float>() : nullptr;
  CHECK_EQ(input_size, idx);
  bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias;
  NDArray data = inputs[in_data];
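  // With sum fusion the convolution accumulates into a buffer pre-filled with
  // the sum operand; otherwise it writes directly to the regular output.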
  NDArray output = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut];

  // Copy inputs[in_sum] into outputs[kOut] in case the in-place optimization failed.
  if (mkldnn_param.with_sum) {
    if (!initalized_) {
      auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
      // TODO(zhennan): Currently, the mkldnn fallback mechanism breaks the inplace option,
      // which makes the check (req[kOut] == kWriteInplace) useless.
      if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
        inplace_ = true;
      }
    }
    if (!inplace_) {
      auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
      const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
      output = NDArray(outputs[kOut].GetMKLDNNData());
    }
  }

  // Check input change
  // TODO(zhennan): Only update the cached_* values that changed.
  if (initalized_) {
    if (mkldnn_param.with_bn) {
      if (weight_ver_ != inputs[in_weight].version() ||
          ((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) {
        initalized_ = false;
      }
    }
    if (initalized_ && mkldnn_param.quantized) {
      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
          weight_ver_ != inputs[in_weight].version() ||
          ((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) {
        initalized_ = false;
      }
    }
  }
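  // Requantization to int8/uint8 happens inside the fused kernel only when
  // calibrated output ranges are available; channel-wise weight scales are
  // enabled in that case.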
  bool post_requantize = false;
  if (mkldnn_param.quantized) {
    if (mkldnn_param.min_calib_range.has_value() &&
        mkldnn_param.max_calib_range.has_value()) {
      post_requantize = true;
      mkldnn_param.weight_channelwise_scale = true;
      *out_min_ptr = mkldnn_param.min_calib_range.value();
      *out_max_ptr = mkldnn_param.max_calib_range.value();
    } else {
      mkldnn_param.weight_channelwise_scale = false;
    }
  }

  if (!initalized_) {
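    // Rebuild the cached state: snapshot quantization ranges and weight/bias
    // versions, fold BN into the weight/bias, quantize them, derive the
    // requantize/sum scales, and recreate the forward primitive.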
    cached_data_min_ = data_min;
    cached_data_max_ = data_max;
    cached_sum_min_ = sum_min;
    cached_sum_max_ = sum_max;
    full_conv_param.sum_scale = 1.0;
    cached_weight_ = inputs[in_weight].Reorder2Default();
    weight_ver_ = inputs[in_weight].version();
    if (!conv_param.no_bias) {
      cached_bias_ = inputs[in_bias].Reorder2Default();
      bias_ver_ = inputs[in_bias].version();
    } else {
      cached_bias_ = NDArray();
    }

    // Update weight and bias after bn fusion.
    if (mkldnn_param.with_bn) {
      CHECK_EQ(inputs[in_weight].dtype(), inputs[in_gamma].dtype());
      CHECK_EQ(inputs[in_weight].dtype(), inputs[in_beta].dtype());
      CHECK_EQ(inputs[in_weight].dtype(), inputs[in_var].dtype());
      MSHADOW_REAL_TYPE_SWITCH(inputs[in_weight].dtype(), DType, {
        UpdateConvWeightBias<DType>(&cached_weight_, &cached_bias_,
                                    conv_param.no_bias, inputs[in_gamma],
                                    inputs[in_beta], inputs[in_mean],
                                    inputs[in_var], bn_param);
      });
    }
    // Quantize weight and bias.
    if (mkldnn_param.quantized) {
      CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
      auto data_range = (data.dtype() == mshadow::kInt8) ? int8_range : uint8_range;
      float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_);
      MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
        QuantizeConvWeightBias<DType>(&cached_weight_, &cached_bias_,
                                      has_bias, data_scale,
                                      mkldnn_param.weight_channelwise_scale,
                                      &weight_scales_);
      });
      // Collect scale.
      size_t channel = cached_weight_.shape()[0];
      float sum_in_scale = 1.0;
      float out_range;
      float quantized_out_range;
      float output_scale;
      if (cached_data_min_ < 0.0) {
        // TODO(zhennan): Support int8 input when mkldnn supports it.
        LOG(FATAL) << "Can't handle negative value for QuantizeData";
      }
      if (mkldnn_param.with_sum) {
        auto quantized_sum_range = cached_sum_min_ < 0 ? int8_range : uint8_range;
        sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_);
      }
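      // Requantize from the conv's internal scale (data_scale * weight_scale)
      // to the calibrated output scale, per output channel when channel-wise
      // weight scales are enabled.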
      if (post_requantize) {
        quantized_out_range =
            IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range;
        out_range = MaxAbs(*out_min_ptr, *out_max_ptr);
        output_scale = quantized_out_range / out_range;
        full_conv_param.requantize_scales.resize(channel);
        for (size_t c = 0; c < channel; c++) {
          auto weight_scale = mkldnn_param.weight_channelwise_scale
                                  ? weight_scales_[c]
                                  : weight_scales_[0];
          full_conv_param.requantize_scales[c] =
              output_scale / data_scale / weight_scale;
        }
      } else {
        output_scale = data_scale * weight_scales_[0];
        full_conv_param.requantize_scales.resize(0);
      }
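      // The fused sum operand carries its own quantization scale
      // (sum_in_scale); rescale it to the output scale before accumulation.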
      if (mkldnn_param.with_sum)
        full_conv_param.sum_scale = output_scale / sum_in_scale;
    }
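    // (Re)create the MKLDNN forward primitive with the cached (possibly
    // BN-folded and quantized) weight and bias.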
    fwd_.reset(new MKLDNNConvForward(
        full_conv_param, ctx.is_train, data, cached_weight_,
        has_bias ? &cached_bias_ : nullptr, output));
  }
  initalized_ = true;
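  // Run the fused convolution with the cached weight/bias in place of the raw
  // weight/bias inputs.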
  std::vector<NDArray> new_inputs;
  std::vector<OpReqType> new_req;
  if (has_bias) {
    new_inputs = {data, cached_weight_, cached_bias_};
    new_req = {req[in_data], req[in_weight], req[in_bias]};
  } else {
    new_inputs = {data, cached_weight_};
    new_req = {req[in_data], req[in_weight]};
  }
  ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs,
                                new_req, {output});

  if (mkldnn_param.with_sum) {
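    // Refresh the output's MKLDNN memory descriptor after the fused sum wrote
    // through a shared buffer.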
    auto out = const_cast<NDArray &>(outputs[kOut]);
    out.UpdateMKLDNNMemDesc();
  }
}