in onnxruntime/core/providers/cpu/quantization/qlinearconv.cc [494:893]
template <typename ActType>
Status QLinearConv<ActType>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(InputTensors::IN_X);
const Tensor* W = is_W_packed_ ? nullptr : context->Input<Tensor>(InputTensors::IN_W);
const auto& W_shape = W ? W->Shape() : W_shape_;
const bool is_W_signed = (W != nullptr) ? W->IsDataType<int8_t>() : is_W_signed_;
const int64_t N = X->Shape()[0];
const int64_t M = W_shape[0];
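// Quantization parameters: zero points for the activation input, output and
// weights, plus the requantization scales. output_scales is expected to hold
// (x_scale * w_scale) / y_scale, either a single value or one entry per output
// channel when the weight scale is per-channel.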
ActType X_zero_point_value;
ActType Y_zero_point_value;
uint8_t W_zero_point_value;
ComputeOffset(context, M, X_zero_point_value, Y_zero_point_value, W_zero_point_value);
std::vector<float> output_scales = ComputeOutputScale(context, M);
const Tensor* B = context->Input<Tensor>(InputTensors::IN_BIAS);
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W_shape, channels_last_));
TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape));
const size_t kernel_rank = kernel_shape.size();
ConvPadVector pads(conv_attrs_.pads);
if (pads.empty()) {
pads.resize(kernel_rank * 2, 0);
}
TensorShapeVector dilations(conv_attrs_.dilations);
if (dilations.empty()) {
dilations.resize(kernel_rank, 1);
}
TensorShapeVector strides(conv_attrs_.strides);
if (strides.empty()) {
strides.resize(kernel_rank, 1);
}
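// With channels_last_ (NHWC) the channel dimension follows the spatial
// dimensions; in the default NCHW layout it is dimension 1.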
const int64_t C = X->Shape()[channels_last_ ? 1 + kernel_rank : 1];
const size_t spatial_dim_start = channels_last_ ? 1 : 2;
const size_t spatial_dim_end = spatial_dim_start + kernel_rank;
TensorShapeVector Y_dims({N});
if (!channels_last_) {
Y_dims.push_back(M);
}
TensorShape input_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end);
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
if (channels_last_) {
Y_dims.push_back(M);
}
Tensor* Y = context->Output(OutputTensors::OUT_Y, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(spatial_dim_start, spatial_dim_end);
// Bail out early if one of the dimensions is zero.
if (Y->Shape().Size() == 0) {
return Status::OK();
}
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
// Handle the case of a dynamic weight filter.
BufferUniquePtr reordered_W_buffer;
uint8_t* reordered_W = nullptr;
if (!packed_W_buffer_) {
if (W == nullptr) {
// Weight was constant and reordered.
reordered_W = static_cast<uint8_t*>(reordered_W_buffer_.get());
} else {
// Weight tensor was not constant or prepacking is disabled.
reordered_W = static_cast<uint8_t*>(alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * W_shape.Size()));
reordered_W_buffer = BufferUniquePtr(reordered_W, BufferDeleter(alloc));
ReorderFilter(
static_cast<const uint8_t*>(W->DataRaw()),
reordered_W,
static_cast<size_t>(M),
static_cast<size_t>(W_shape[1]),
static_cast<size_t>(kernel_size));
}
}
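// In this path the filter is reordered into an effectively
// [group_input_channels * kernel_size, M] layout so the quantized GEMM below can
// address its B operand with ldb == M and a per-group column offset.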
int64_t group_count = conv_attrs_.group;
int64_t group_input_channels = W_shape[1];
int64_t group_output_channels = M / group_count;
// Test for depthwise convolution.
const bool is_depthwise_conv = ((is_symmetric_conv_ || reordered_W != nullptr) && group_input_channels == 1 && group_output_channels == 1);
if (is_depthwise_conv) {
// Update the input and output channels to the number of groups in order to
// reuse as much of the standard convolution path below as possible.
group_input_channels = group_count;
group_output_channels = group_count;
group_count = 1;
}
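// Per-image strides into the activation tensors and the GEMM dimensions:
// kernel_dim is the reduction (K) dimension, input channels per group times
// kernel taps, and col_buffer_size is one group's im2col output for a full image.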
const int64_t X_offset = C * input_image_size;
const int64_t Y_offset = M * output_image_size;
const int64_t kernel_dim = group_input_channels * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;
// Use an intermediate int32_t buffer for the GEMM computation before
// requantizing to the output type.
//
// This buffer is not needed for the symmetric convolution path, as requantization
// is fused with the GEMM computation.
BufferUniquePtr gemm_output_buffer;
if (!is_symmetric_conv_) {
auto* gemm_output_data = alloc->Alloc(SafeInt<size_t>(sizeof(int32_t)) * Y_offset);
gemm_output_buffer = BufferUniquePtr(gemm_output_data, BufferDeleter(alloc));
}
const auto* Xdata = X->template Data<ActType>();
const auto* Bdata = B != nullptr ? B->template Data<int32_t>() : nullptr;
auto* Ydata = Y->template MutableData<ActType>();
BufferUniquePtr transpose_input_buffer;
BufferUniquePtr transpose_output_buffer;
// Allocate temporary buffers for transposing to channels last format.
if (!channels_last_) {
auto* transpose_input = alloc->Alloc(SafeInt<size_t>(sizeof(ActType)) * (X_offset + MLAS_SYMM_QGEMM_BUF_OVERRUN));
transpose_input_buffer = BufferUniquePtr(transpose_input, BufferDeleter(alloc));
auto* transpose_output = alloc->Alloc(SafeInt<size_t>(sizeof(ActType)) * Y_offset);
transpose_output_buffer = BufferUniquePtr(transpose_output, BufferDeleter(alloc));
}
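// MLAS_SYMM_QGEMM_BUF_OVERRUN pads activation-side buffers because the symmetric
// QGEMM kernels may read slightly past the logical end of the A operand.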
BufferUniquePtr col_buffer;
BufferUniquePtr indirection_buffer;
std::vector<ActType> padding_data;
bool use_indirection_buffer = false;
if (is_depthwise_conv) {
use_indirection_buffer = true;
} else if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
if (is_symmetric_conv_) {
use_indirection_buffer = true;
} else {
// Pointwise convolutions can use the original input tensor in place;
// otherwise a temporary buffer is required for the im2col transform.
int64_t group_col_buffer_size = (kernel_rank > 2) ? group_count * col_buffer_size : col_buffer_size;
group_col_buffer_size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(ActType)) * group_col_buffer_size);
col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc));
}
}
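// The indirection buffer stores kernel_size input pointers per output pixel; taps
// that fall in the padding region point at padding_data, which is filled with the
// input zero point so padded positions contribute a quantized zero.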
if (use_indirection_buffer) {
// Allocate indirection buffer pointers and prepare a padding vector for
// the im2col transform.
auto* indirection_data = alloc->Alloc(SafeInt<size_t>(sizeof(const ActType*)) * kernel_size * output_image_size);
indirection_buffer = BufferUniquePtr(indirection_data, BufferDeleter(alloc));
padding_data.resize(static_cast<size_t>(C), X_zero_point_value);
}
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
#if defined(_M_ARM64) || defined(__aarch64__)
int32_t task_count = static_cast<int32_t>((output_image_size + (GEMM_KERNEL_STRIDE_M - 1)) / GEMM_KERNEL_STRIDE_M);
#else
int32_t task_count = ComputeTaskCount(output_image_size, group_output_channels, kernel_dim);
task_count = std::min(task_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool));
#endif
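// On ARM64 the output pixels are split into fixed blocks of GEMM_KERNEL_STRIDE_M;
// otherwise the task count is derived from the problem size and capped at the
// thread pool's degree of parallelism.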
for (int64_t image_id = 0; image_id < N; ++image_id) {
const auto* input_data = Xdata;
auto* output_data = Ydata;
if (!channels_last_) {
// Transpose the input from channels first (NCHW) to channels last (NHWC).
MlasTranspose(
Xdata,
static_cast<ActType*>(transpose_input_buffer.get()),
static_cast<size_t>(C),
static_cast<size_t>(input_image_size));
input_data = static_cast<ActType*>(transpose_input_buffer.get());
output_data = static_cast<ActType*>(transpose_output_buffer.get());
}
// Threaded implementation of ND convolution is not yet supported, so
// prepare all im2col transformations here.
if (col_buffer && kernel_rank > 2) {
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
math::Im2col<ActType, StorageOrder::NHWC>()(
input_data + group_id * group_input_channels,
group_input_channels,
C,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int64_t>(kernel_rank),
static_cast<ActType*>(col_buffer.get()) + group_id * col_buffer_size,
X_zero_point_value);
}
}
auto conv_worker = [&](ptrdiff_t batch) {
#if defined(_M_ARM64) || defined(__aarch64__)
int64_t output_start = batch * GEMM_KERNEL_STRIDE_M;
int64_t output_count = std::min(static_cast<int64_t>(GEMM_KERNEL_STRIDE_M), output_image_size - output_start);
#else
auto work = concurrency::ThreadPool::PartitionWork(batch, task_count, static_cast<ptrdiff_t>(output_image_size));
int64_t output_start = static_cast<int64_t>(work.start);
int64_t output_count = static_cast<int64_t>(work.end) - work.start;
#endif
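// Each task handles the contiguous range of output pixels
// [output_start, output_start + output_count).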
ActType const** worker_indirection_buffer = nullptr;
if (indirection_buffer) {
worker_indirection_buffer = static_cast<ActType const**>(indirection_buffer.get()) + output_start * kernel_size;
math::Im2col<ActType, StorageOrder::NHWC>()(
input_data,
C,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<ptrdiff_t>(kernel_rank),
output_start,
output_count,
worker_indirection_buffer,
padding_data.data());
}
auto* worker_output = output_data + output_start * M;
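// Symmetric path: the MLAS convolution kernel fuses the GEMM with requantization.
// conv_params.Bias points at column_sums_, which is expected to carry the
// per-output-channel corrections (and any constant bias) precomputed at prepack time.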
if (is_symmetric_conv_) {
MLAS_CONV_SYM_PARAMS conv_params = {};
if (worker_indirection_buffer) {
conv_params.InputIndirection = reinterpret_cast<void const**>(worker_indirection_buffer);
} else {
conv_params.InputDirect = input_data + output_start * C;
}
conv_params.Filter = packed_W_buffer_.get();
conv_params.Output = worker_output;
conv_params.InputChannels = static_cast<size_t>(C);
conv_params.OutputChannels = static_cast<size_t>(M);
conv_params.OutputCount = static_cast<size_t>(output_count);
conv_params.KernelSize = static_cast<size_t>(kernel_size);
conv_params.Bias = column_sums_.data();
conv_params.Scale = output_scales.data();
conv_params.PerChannelScale = output_scales.size() > 1;
conv_params.OutputZeroPoint = Y_zero_point_value;
conv_params.InputIsSigned = std::is_signed<ActType>::value;
if (is_depthwise_conv) {
MlasConvSymDepthwise(conv_params);
} else {
MlasConvSym(conv_params);
}
return;
}
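// Non-symmetric paths accumulate int32_t results into the scratch buffer and
// requantize them to the activation type below.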
auto* worker_gemm_output = static_cast<int32_t*>(gemm_output_buffer.get()) + output_start * M;
if (is_depthwise_conv) {
MlasConvDepthwise(
reinterpret_cast<const void* const*>(worker_indirection_buffer),
X_zero_point_value,
std::is_signed<ActType>::value,
reinterpret_cast<const void* const*>(reordered_W),
W_zero_point_value,
is_W_signed,
worker_gemm_output,
static_cast<size_t>(M),
static_cast<size_t>(output_count),
static_cast<size_t>(kernel_size));
} else {
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
// Prepare the im2col transformation or use the input buffer directly for
// pointwise convolutions.
const auto* group_input_data = input_data + group_id * group_input_channels;
const uint8_t* AData;
size_t lda;
if (col_buffer) {
auto* worker_col_buffer = static_cast<ActType*>(col_buffer.get()) + output_start * kernel_dim;
if (kernel_rank == 2) {
math::Im2col<ActType, StorageOrder::NHWC>()(
group_input_data,
group_input_channels,
C,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
strides[0],
strides[1],
output_shape[1],
output_start,
output_count,
worker_col_buffer,
X_zero_point_value);
} else if (kernel_rank == 1) {
math::Im2col<ActType, StorageOrder::NHWC>()(
group_input_data,
group_input_channels,
C,
1,
input_shape[0],
1,
kernel_shape[0],
1,
dilations[0],
0,
pads[0],
1,
strides[0],
output_shape[0],
output_start,
output_count,
worker_col_buffer,
X_zero_point_value);
} else {
// Use the im2col buffer prepared outside the thread, indexed by group.
worker_col_buffer += group_id * col_buffer_size;
}
AData = reinterpret_cast<const uint8_t*>(worker_col_buffer);
lda = static_cast<size_t>(kernel_dim);
} else {
AData = reinterpret_cast<const uint8_t*>(group_input_data + output_start * C);
lda = static_cast<size_t>(C);
}
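// GEMM over this worker's output range: M is the number of output pixels,
// N the output channels per group, and K the im2col reduction dimension.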
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(output_count);
gemm_shape.N = static_cast<size_t>(group_output_channels);
gemm_shape.K = static_cast<size_t>(kernel_dim);
gemm_shape.AIsSigned = std::is_signed<ActType>::value;
gemm_shape.BIsSigned = is_W_signed;
if (is_symmetric_gemm_) {
MLAS_SYMM_QGEMM_DATA_PARAMS symm_gemm;
symm_gemm.A = AData;
symm_gemm.lda = lda;
symm_gemm.C = worker_gemm_output + group_id * group_output_channels;
symm_gemm.ldc = static_cast<size_t>(M);
symm_gemm.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_;
MlasSymmQgemmBatch(gemm_shape, &symm_gemm, 1, nullptr);
} else {
MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
gemm_params.ZeroPointA = static_cast<uint8_t>(X_zero_point_value);
gemm_params.A = AData;
gemm_params.lda = lda;
if (packed_W_buffer_) {
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_;
gemm_params.BIsPacked = true;
} else {
gemm_params.B = reordered_W + group_id * group_output_channels;
gemm_params.ldb = static_cast<size_t>(M);
}
gemm_params.ZeroPointB = &W_zero_point_value;
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
gemm_params.ldc = static_cast<size_t>(M);
MlasGemm(gemm_shape, gemm_params, nullptr);
}
}
}
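// Requantize the int32_t accumulators: add the optional bias, scale by the
// per-tensor or per-channel output scale, then offset by the output zero point
// and saturate to the activation type.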
MlasRequantizeOutput(
worker_gemm_output,
static_cast<size_t>(M),
worker_output,
static_cast<size_t>(M),
Bdata,
output_scales.data(),
output_scales.size() > 1,
Y_zero_point_value,
0,
0,
static_cast<size_t>(output_count),
static_cast<size_t>(M));
};
concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, task_count, conv_worker);
if (!channels_last_) {
// Transpose the output from channels last (NHWC) to channels first (NCHW).
MlasTranspose(
output_data,
Ydata,
static_cast<size_t>(output_image_size),
static_cast<size_t>(M));
}
Xdata += X_offset;
Ydata += Y_offset;
}
return Status::OK();
}