in src/operator/subgraph/mkldnn/mkldnn_conv.cc [209:393]
void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
auto &full_conv_param = param_.full_conv_param;
auto &mkldnn_param = param_.full_conv_param.mkldnn_param;
auto &conv_param = param_.full_conv_param.conv_param;
auto bn_param = param_.bn_param.get();
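// Expected input layout: data and weight always come first, followed by an
// optional bias, the four bn inputs (gamma/beta/mean/var) when bn is fused,
// the sum input when elementwise-add is fused, and, in quantized mode,
// data_min/data_max plus sum_min/sum_max when sum is also fused. E.g. a
// quantized conv + bn + sum with bias expects:
// [data, weight, bias, gamma, beta, mean, var, sum,
//  data_min, data_max, sum_min, sum_max]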
size_t input_size =
2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) +
(mkldnn_param.with_sum ? 1 : 0) +
(mkldnn_param.quantized ? 2 + (mkldnn_param.with_sum ? 2 : 0) : 0);
CHECK_EQ(inputs.size(), input_size);
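// Assign slot indices in the same order; slots for disabled features fall
// back to index 0 and are never read.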
size_t idx = 0;
auto in_data = idx++;
auto in_weight = idx++;
auto in_bias = conv_param.no_bias ? 0 : (idx++);
auto in_gamma = mkldnn_param.with_bn ? (idx++) : 0;
auto in_beta = mkldnn_param.with_bn ? (idx++) : 0;
auto in_mean = mkldnn_param.with_bn ? (idx++) : 0;
auto in_var = mkldnn_param.with_bn ? (idx++) : 0;
auto in_sum = mkldnn_param.with_sum ? (idx++) : 0;
float data_min =
mkldnn_param.quantized ? inputs[idx++].data().dptr<float>()[0] : 0.0;
float data_max =
mkldnn_param.quantized ? inputs[idx++].data().dptr<float>()[0] : 0.0;
float sum_min = (mkldnn_param.with_sum && mkldnn_param.quantized)
? inputs[idx++].data().dptr<float>()[0]
: 0.0;
float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized)
? inputs[idx++].data().dptr<float>()[0]
: 0.0;
float *out_min_ptr =
mkldnn_param.quantized ? outputs[kMin].data().dptr<float>() : nullptr;
float *out_max_ptr =
mkldnn_param.quantized ? outputs[kMax].data().dptr<float>() : nullptr;
CHECK_EQ(input_size, idx);
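// bn fusion folds its shift term into a conv bias, so a bias exists even if
// the convolution itself was created with no_bias.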
bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias;
NDArray data = inputs[in_data];
NDArray output = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut];
// Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
if (mkldnn_param.with_sum) {
if (!initalized_) {
auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
// TODO(zhennan): Currently, the mkldnn fallback mechanism breaks the inplace
// option, which makes the check (req[kOut] == kWriteInplace) useless.
if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
inplace_ = true;
}
}
if (!inplace_) {
auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
output = NDArray(outputs[kOut].GetMKLDNNData());
}
}
// Check whether any cached input has changed.
// TODO(zhennan): Only update the cached_* values that actually changed.
if (initalized_) {
if (mkldnn_param.with_bn) {
if (weight_ver_ != inputs[in_weight].version() ||
((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) {
initalized_ = false;
}
}
if (initalized_ && mkldnn_param.quantized) {
if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
weight_ver_ != inputs[in_weight].version() ||
((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) {
initalized_ = false;
}
}
}
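// Calibrated min/max ranges, when available, let the conv primitive
// requantize its int32 accumulator directly to the final 8-bit output range
// and enable per-channel weight scales.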
bool post_requantize = false;
if (mkldnn_param.quantized) {
if (mkldnn_param.min_calib_range.has_value() &&
mkldnn_param.max_calib_range.has_value()) {
post_requantize = true;
mkldnn_param.weight_channelwise_scale = true;
*out_min_ptr = mkldnn_param.min_calib_range.value();
*out_max_ptr = mkldnn_param.max_calib_range.value();
} else {
mkldnn_param.weight_channelwise_scale = false;
}
}
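// First forward call, or a cached input changed: rebuild the fused
// weight/bias and the conv primitive below.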
if (!initalized_) {
cached_data_min_ = data_min;
cached_data_max_ = data_max;
cached_sum_min_ = sum_min;
cached_sum_max_ = sum_max;
full_conv_param.sum_scale = 1.0;
cached_weight_ = inputs[in_weight].Reorder2Default();
weight_ver_ = inputs[in_weight].version();
if (!conv_param.no_bias) {
cached_bias_ = inputs[in_bias].Reorder2Default();
bias_ver_ = inputs[in_bias].version();
} else {
cached_bias_ = NDArray();
}
// Fold batch norm into the cached weight and bias.
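// Conceptually this is standard bn folding:
// w' = w * gamma / sqrt(var + eps)
// b' = (b - mean) * gamma / sqrt(var + eps) + beta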
if (mkldnn_param.with_bn) {
CHECK_EQ(inputs[in_weight].dtype(), inputs[in_gamma].dtype());
CHECK_EQ(inputs[in_weight].dtype(), inputs[in_beta].dtype());
CHECK_EQ(inputs[in_weight].dtype(), inputs[in_var].dtype());
MSHADOW_REAL_TYPE_SWITCH(inputs[in_weight].dtype(), DType, {
UpdateConvWeightBias<DType>(&cached_weight_, &cached_bias_,
conv_param.no_bias, inputs[in_gamma],
inputs[in_beta], inputs[in_mean],
inputs[in_var], bn_param);
});
}
// Quantize weight and bias.
if (mkldnn_param.quantized) {
CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
auto data_range = (data.dtype() == mshadow::kInt8) ? int8_range : uint8_range;
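// data_scale maps the recorded float activation range onto the matching
// 8-bit integer range (uint8 for non-negative data, int8 otherwise).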
float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_);
MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
QuantizeConvWeightBias<DType>(&cached_weight_, &cached_bias_,
has_bias, data_scale,
mkldnn_param.weight_channelwise_scale,
&weight_scales_);
});
// Collect the scales needed for output (re)quantization.
size_t channel = cached_weight_.shape()[0];
float sum_in_scale = 1.0;
float output_scale;
if (cached_data_min_ < 0.0) {
// TODO(zhennan): Support int8 input once mkldnn supports it.
LOG(FATAL) << "Can't handle negative value for QuantizeData";
}
if (mkldnn_param.with_sum) {
auto quantized_sum_range = cached_sum_min_ < 0 ? int8_range : uint8_range;
sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_);
}
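// With calibration, requantization is folded into the primitive:
// out_q = conv(data_q, weight_q) * output_scale / (data_scale * weight_scale),
// evaluated per output channel when channel-wise scales are enabled.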
if (post_requantize) {
float quantized_out_range =
IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range;
float out_range = MaxAbs(*out_min_ptr, *out_max_ptr);
output_scale = quantized_out_range / out_range;
full_conv_param.requantize_scales.resize(channel);
for (size_t c = 0; c < channel; c++) {
auto weight_scale = mkldnn_param.weight_channelwise_scale
? weight_scales_[c]
: weight_scales_[0];
full_conv_param.requantize_scales[c] =
output_scale / data_scale / weight_scale;
}
} else {
output_scale = data_scale * weight_scales_[0];
full_conv_param.requantize_scales.resize(0);
}
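// Rescale the fused sum input from its own quantized range into the conv
// output's scale so the two can be added directly.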
if (mkldnn_param.with_sum)
full_conv_param.sum_scale = output_scale / sum_in_scale;
}
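// Build and cache the forward primitive; it is reused on later calls until
// the cached inputs or ranges above change.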
fwd_.reset(new MKLDNNConvForward(
full_conv_param, ctx.is_train, data, cached_weight_,
has_bias ? &cached_bias_ : nullptr, output));
}
initalized_ = true;
std::vector<NDArray> new_inputs;
std::vector<OpReqType> new_req;
if (has_bias) {
new_inputs = {data, cached_weight_, cached_bias_};
new_req = {req[in_data], req[in_weight], req[in_bias]};
} else {
new_inputs = {data, cached_weight_};
new_req = {req[in_data], req[in_weight]};
}
ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs,
new_req, {output});
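// After writing through the sum input's memory, refresh outputs[kOut]'s
// mkldnn memory descriptor so downstream ops see the format actually
// written.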
if (mkldnn_param.with_sum) {
auto out = const_cast<NDArray &>(outputs[kOut]);
out.UpdateMKLDNNMemDesc();
}
}