Status QLinearConv::Compute()

in onnxruntime/core/providers/cpu/quantization/qlinearconv.cc [494:893]


// Runs the quantized (QLinear) convolution for one kernel invocation.
//
// Flow, as implemented below:
//   1. Resolve the weight shape and signedness, either from the W input or, when
//      is_W_packed_ is set, from cached members (the W input is then not read).
//   2. Read zero points and output scales, validate shapes, and infer the output
//      shape. For channels_last_ the channel axis is last, otherwise it is axis 1.
//   3. Per image, pick one of three compute strategies:
//      - Symmetric conv: MlasConvSym / MlasConvSymDepthwise, with
//        requantization fused into the kernel.
//      - Depthwise conv: MlasConvDepthwise over an indirection buffer.
//      - Grouped GEMM: an im2col buffer (or the input itself for pointwise
//        convs) feeds MlasSymmQgemmBatch / MlasGemm.
//      Both non-symmetric strategies finish with MlasRequantizeOutput.
//   4. Channels-first data is transposed to NHWC before compute and back after.
//
// Work within an image is split across output pixels with
// ThreadPool::TrySimpleParallelFor.
//
// @param context  kernel context supplying inputs (X, W, zero points, scales,
//                 optional bias) and the output Y.
// @return Status::OK() on success (including empty outputs), or the error status
//         from shape validation, kernel shape / output shape inference, or the
//         temp-space allocator lookup.
Status QLinearConv<ActType>::Compute(OpKernelContext* context) const {
  const Tensor* X = context->Input<Tensor>(InputTensors::IN_X);
  // When the weight is prepacked, the W input is not read. Its shape and
  // signedness come from the cached W_shape_ / is_W_signed_ members (presumably
  // recorded during prepacking).
  const Tensor* W = is_W_packed_ ? nullptr : context->Input<Tensor>(InputTensors::IN_W);
  const auto& W_shape = W ? W->Shape() : W_shape_;
  const bool is_W_signed = (W != nullptr) ? W->IsDataType<int8_t>() : is_W_signed_;

  const int64_t N = X->Shape()[0];
  const int64_t M = W_shape[0];

  ActType X_zero_point_value;
  ActType Y_zero_point_value;
  uint8_t W_zero_point_value;
  ComputeOffset(context, M, X_zero_point_value, Y_zero_point_value, W_zero_point_value);
  std::vector<float> output_scales = ComputeOutputScale(context, M);

  const Tensor* B = context->Input<Tensor>(InputTensors::IN_BIAS);

  ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W_shape, channels_last_));

  TensorShapeVector kernel_shape;
  ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape));

  const size_t kernel_rank = kernel_shape.size();

  // Unspecified attributes default to zero padding and unit dilation/stride
  // on every spatial axis.
  ConvPadVector pads(conv_attrs_.pads);
  if (pads.empty()) {
    pads.resize(kernel_rank * 2, 0);
  }
  TensorShapeVector dilations(conv_attrs_.dilations);
  if (dilations.empty()) {
    dilations.resize(kernel_rank, 1);
  }
  TensorShapeVector strides(conv_attrs_.strides);
  if (strides.empty()) {
    strides.resize(kernel_rank, 1);
  }

  // The channel axis is last for channels-last input, otherwise it is axis 1.
  const int64_t C = X->Shape()[channels_last_ ? 1 + kernel_rank : 1];
  const size_t spatial_dim_start = channels_last_ ? 1 : 2;
  const size_t spatial_dim_end = spatial_dim_start + kernel_rank;

  // The output keeps the input layout: {N, M, spatial...} or {N, spatial..., M}.
  TensorShapeVector Y_dims({N});
  if (!channels_last_) {
    Y_dims.push_back(M);
  }
  TensorShape input_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end);
  ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
  if (channels_last_) {
    Y_dims.push_back(M);
  }
  Tensor* Y = context->Output(OutputTensors::OUT_Y, TensorShape(Y_dims));
  TensorShape output_shape = Y->Shape().Slice(spatial_dim_start, spatial_dim_end);

  // Bail out early if one of the dimensions is zero.
  if (Y->Shape().Size() == 0) {
    return Status::OK();
  }

  const int64_t input_image_size = input_shape.Size();
  const int64_t output_image_size = output_shape.Size();
  const int64_t kernel_size = TensorShape(kernel_shape).Size();

  AllocatorPtr alloc;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));

  // Handle the case of a dynamic weight filter.
  BufferUniquePtr reordered_W_buffer;
  uint8_t* reordered_W = nullptr;
  if (!packed_W_buffer_) {
    if (W == nullptr) {
      // Weight was constant and reordered.
      reordered_W = static_cast<uint8_t*>(reordered_W_buffer_.get());
    } else {
      // Weight tensor was not constant or prepacking is disabled.
      reordered_W = static_cast<uint8_t*>(alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * W_shape.Size()));
      reordered_W_buffer = BufferUniquePtr(reordered_W, BufferDeleter(alloc));
      ReorderFilter(
          static_cast<const uint8_t*>(W->DataRaw()),
          reordered_W,
          static_cast<size_t>(M),
          static_cast<size_t>(W_shape[1]),
          static_cast<size_t>(kernel_size));
    }
  }

  int64_t group_count = conv_attrs_.group;
  int64_t group_input_channels = W_shape[1];
  int64_t group_output_channels = M / group_count;

  // Test for depthwise convolution: each group maps exactly one input channel
  // to one output channel.
  const bool is_depthwise_conv = ((is_symmetric_conv_ || reordered_W != nullptr) && group_input_channels == 1 && group_output_channels == 1);
  if (is_depthwise_conv) {
    // Update the input and output channels to the number of groups in order to
    // reuse as much of the below standard convolution path.
    group_input_channels = group_count;
    group_output_channels = group_count;
    group_count = 1;
  }

  const int64_t X_offset = C * input_image_size;
  const int64_t Y_offset = M * output_image_size;
  const int64_t kernel_dim = group_input_channels * kernel_size;
  const int64_t col_buffer_size = kernel_dim * output_image_size;

  // Use an intermediate int32_t buffer for the GEMM computation before
  // requantizing to the output type.
  //
  // This buffer is not needed for the symmetric convolution path as requantization
  // is fused with the GEMM computation.
  BufferUniquePtr gemm_output_buffer;
  if (!is_symmetric_conv_) {
    auto* gemm_output_data = alloc->Alloc(SafeInt<size_t>(sizeof(int32_t)) * Y_offset);
    gemm_output_buffer = BufferUniquePtr(gemm_output_data, BufferDeleter(alloc));
  }

  const auto* Xdata = X->template Data<ActType>();
  const auto* Bdata = B != nullptr ? B->template Data<int32_t>() : nullptr;
  auto* Ydata = Y->template MutableData<ActType>();

  BufferUniquePtr transpose_input_buffer;
  BufferUniquePtr transpose_output_buffer;

  // Allocate temporary buffers for transposing to channels last format.
  // The input buffer is padded by MLAS_SYMM_QGEMM_BUF_OVERRUN elements,
  // presumably because MLAS kernels may read past the logical end.
  if (!channels_last_) {
    auto* transpose_input = alloc->Alloc(SafeInt<size_t>(sizeof(ActType)) * (X_offset + MLAS_SYMM_QGEMM_BUF_OVERRUN));
    transpose_input_buffer = BufferUniquePtr(transpose_input, BufferDeleter(alloc));
    auto* transpose_output = alloc->Alloc(SafeInt<size_t>(sizeof(ActType)) * Y_offset);
    transpose_output_buffer = BufferUniquePtr(transpose_output, BufferDeleter(alloc));
  }

  BufferUniquePtr col_buffer;
  BufferUniquePtr indirection_buffer;
  std::vector<ActType> padding_data;

  // Choose how the A operand is formed:
  //  - indirection buffer (depthwise, or symmetric non-pointwise convs),
  //  - explicit im2col buffer (other non-pointwise convs),
  //  - neither: pointwise convs with unit strides and no padding read X directly.
  bool use_indirection_buffer = false;
  if (is_depthwise_conv) {
    use_indirection_buffer = true;
  } else if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
    if (is_symmetric_conv_) {
      use_indirection_buffer = true;
    } else {
      // Pointwise convolutions can use the original input tensor in place,
      // otherwise a temporary buffer is required for the im2col transform.
      // For rank > 2 the im2col is done up front for every group, so the buffer
      // holds one col_buffer_size slice per group.
      int64_t group_col_buffer_size = (kernel_rank > 2) ? group_count * col_buffer_size : col_buffer_size;
      group_col_buffer_size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
      auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(ActType)) * group_col_buffer_size);
      col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc));
    }
  }
  if (use_indirection_buffer) {
    // Allocate indirection buffer pointers and prepare a padding vector for
    // the im2col transform. Padded positions point at a row of input zero points.
    auto* indirection_data = alloc->Alloc(SafeInt<size_t>(sizeof(const ActType*)) * kernel_size * output_image_size);
    indirection_buffer = BufferUniquePtr(indirection_data, BufferDeleter(alloc));
    padding_data.resize(static_cast<size_t>(C), X_zero_point_value);
  }

  concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
#if defined(_M_ARM64) || defined(__aarch64__)
  // One task per GEMM_KERNEL_STRIDE_M output pixels (rounded up). The narrowing
  // to int32_t is explicit so it is not mistaken for an accident.
  int32_t task_count = static_cast<int32_t>((output_image_size + (GEMM_KERNEL_STRIDE_M - 1)) / GEMM_KERNEL_STRIDE_M);
#else
  int32_t task_count = ComputeTaskCount(output_image_size, group_output_channels, kernel_dim);
  task_count = std::min(task_count, concurrency::ThreadPool::DegreeOfParallelism(thread_pool));
#endif

  for (int64_t image_id = 0; image_id < N; ++image_id) {
    const auto* input_data = Xdata;
    auto* output_data = Ydata;

    if (!channels_last_) {
      // Transpose the input from channels first (NCHW) to channels last (NHWC).
      MlasTranspose(
          Xdata,
          static_cast<ActType*>(transpose_input_buffer.get()),
          static_cast<size_t>(C),
          static_cast<size_t>(input_image_size));
      input_data = static_cast<ActType*>(transpose_input_buffer.get());
      output_data = static_cast<ActType*>(transpose_output_buffer.get());
    }

    // Threaded implementation of ND convolution is not yet supported, so
    // prepare all im2col transformations here.
    if (col_buffer && kernel_rank > 2) {
      for (int64_t group_id = 0; group_id < group_count; ++group_id) {
        math::Im2col<ActType, StorageOrder::NHWC>()(
            input_data + group_id * group_input_channels,
            group_input_channels,
            C,
            input_shape.GetDims().data(),
            output_shape.GetDims().data(),
            kernel_shape.data(),
            strides.data(),
            dilations.data(),
            pads.data(),
            static_cast<int64_t>(kernel_rank),
            static_cast<ActType*>(col_buffer.get()) + group_id * col_buffer_size,
            X_zero_point_value);
      }
    }

    // Each task handles a contiguous range of output pixels
    // [output_start, output_start + output_count) for all output channels.
    auto conv_worker = [&](ptrdiff_t batch) {
#if defined(_M_ARM64) || defined(__aarch64__)
      int64_t output_start = batch * GEMM_KERNEL_STRIDE_M;
      int64_t output_count = std::min((int64_t)GEMM_KERNEL_STRIDE_M, output_image_size - output_start);
#else
      auto work = concurrency::ThreadPool::PartitionWork(batch, task_count, static_cast<ptrdiff_t>(output_image_size));
      int64_t output_start = static_cast<int64_t>(work.start);
      int64_t output_count = static_cast<int64_t>(work.end) - work.start;
#endif

      // Fill only this task's slice of the indirection buffer.
      ActType const** worker_indirection_buffer = nullptr;
      if (indirection_buffer) {
        worker_indirection_buffer = static_cast<ActType const**>(indirection_buffer.get()) + output_start * kernel_size;
        math::Im2col<ActType, StorageOrder::NHWC>()(
            input_data,
            C,
            input_shape.GetDims().data(),
            output_shape.GetDims().data(),
            kernel_shape.data(),
            strides.data(),
            dilations.data(),
            pads.data(),
            static_cast<ptrdiff_t>(kernel_rank),
            output_start,
            output_count,
            worker_indirection_buffer,
            padding_data.data());
      }

      auto* worker_output = output_data + output_start * M;

      if (is_symmetric_conv_) {
        // Symmetric path: MLAS performs the convolution and requantization in
        // one call, writing directly to the ActType output.
        MLAS_CONV_SYM_PARAMS conv_params = {};
        if (worker_indirection_buffer) {
          conv_params.InputIndirection = reinterpret_cast<void const**>(worker_indirection_buffer);
        } else {
          conv_params.InputDirect = input_data + output_start * C;
        }
        conv_params.Filter = packed_W_buffer_.get();
        conv_params.Output = worker_output;
        conv_params.InputChannels = static_cast<size_t>(C);
        conv_params.OutputChannels = static_cast<size_t>(M);
        conv_params.OutputCount = static_cast<size_t>(output_count);
        conv_params.KernelSize = static_cast<size_t>(kernel_size);
        conv_params.Bias = column_sums_.data();
        conv_params.Scale = output_scales.data();
        conv_params.PerChannelScale = output_scales.size() > 1;
        conv_params.OutputZeroPoint = Y_zero_point_value;
        conv_params.InputIsSigned = std::is_signed<ActType>::value;

        if (is_depthwise_conv) {
          MlasConvSymDepthwise(conv_params);
        } else {
          MlasConvSym(conv_params);
        }
        return;
      }

      auto* worker_gemm_output = static_cast<int32_t*>(gemm_output_buffer.get()) + output_start * M;

      if (is_depthwise_conv) {
        MlasConvDepthwise(
            reinterpret_cast<const void* const*>(worker_indirection_buffer),
            X_zero_point_value,
            std::is_signed<ActType>::value,
            reinterpret_cast<const void* const*>(reordered_W),
            W_zero_point_value,
            is_W_signed,
            worker_gemm_output,
            static_cast<size_t>(M),
            static_cast<size_t>(output_count),
            static_cast<size_t>(kernel_size));
      } else {
        for (int64_t group_id = 0; group_id < group_count; ++group_id) {
          // Prepare the im2col transformation or use the input buffer directly for
          // pointwise convolutions.
          const auto* group_input_data = input_data + group_id * group_input_channels;
          const uint8_t* AData;
          size_t lda;
          if (col_buffer) {
            auto* worker_col_buffer = static_cast<ActType*>(col_buffer.get()) + output_start * kernel_dim;
            if (kernel_rank == 2) {
              math::Im2col<ActType, StorageOrder::NHWC>()(
                  group_input_data,
                  group_input_channels,
                  C,
                  input_shape[0],
                  input_shape[1],
                  kernel_shape[0],
                  kernel_shape[1],
                  dilations[0],
                  dilations[1],
                  pads[0],
                  pads[1],
                  strides[0],
                  strides[1],
                  output_shape[1],
                  output_start,
                  output_count,
                  worker_col_buffer,
                  X_zero_point_value);
            } else if (kernel_rank == 1) {
              // 1-D convolution is expressed as 2-D with a unit-height image.
              math::Im2col<ActType, StorageOrder::NHWC>()(
                  group_input_data,
                  group_input_channels,
                  C,
                  1,
                  input_shape[0],
                  1,
                  kernel_shape[0],
                  1,
                  dilations[0],
                  0,
                  pads[0],
                  1,
                  strides[0],
                  output_shape[0],
                  output_start,
                  output_count,
                  worker_col_buffer,
                  X_zero_point_value);
            } else {
              // Use the im2col buffer prepared outside the thread, indexed by group.
              worker_col_buffer += group_id * col_buffer_size;
            }
            AData = reinterpret_cast<const uint8_t*>(worker_col_buffer);
            lda = static_cast<size_t>(kernel_dim);
          } else {
            AData = reinterpret_cast<const uint8_t*>(group_input_data + output_start * C);
            lda = static_cast<size_t>(C);
          }

          MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
          gemm_shape.M = static_cast<size_t>(output_count);
          gemm_shape.N = static_cast<size_t>(group_output_channels);
          gemm_shape.K = static_cast<size_t>(kernel_dim);
          gemm_shape.AIsSigned = std::is_signed<ActType>::value;
          gemm_shape.BIsSigned = is_W_signed;

          // Each group writes its group_output_channels columns of the shared
          // int32 output, whose row stride is M.
          if (is_symmetric_gemm_) {
            MLAS_SYMM_QGEMM_DATA_PARAMS symm_gemm;
            symm_gemm.A = AData;
            symm_gemm.lda = lda;
            symm_gemm.C = worker_gemm_output + group_id * group_output_channels;
            symm_gemm.ldc = static_cast<size_t>(M);
            symm_gemm.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_;
            MlasSymmQgemmBatch(gemm_shape, &symm_gemm, 1, nullptr);
          } else {
            MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
            gemm_params.ZeroPointA = static_cast<uint8_t>(X_zero_point_value);
            gemm_params.A = AData;
            gemm_params.lda = lda;
            if (packed_W_buffer_) {
              gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_;
              gemm_params.BIsPacked = true;
            } else {
              gemm_params.B = reordered_W + group_id * group_output_channels;
              gemm_params.ldb = static_cast<size_t>(M);
            }
            gemm_params.ZeroPointB = &W_zero_point_value;
            gemm_params.C = worker_gemm_output + group_id * group_output_channels;
            gemm_params.ldc = static_cast<size_t>(M);

            MlasGemm(gemm_shape, gemm_params, nullptr);
          }
        }
      }

      // Convert the int32 accumulators (plus optional bias) to ActType.
      MlasRequantizeOutput(
          worker_gemm_output,
          static_cast<size_t>(M),
          worker_output,
          static_cast<size_t>(M),
          Bdata,
          output_scales.data(),
          output_scales.size() > 1,
          Y_zero_point_value,
          0,
          0,
          static_cast<size_t>(output_count),
          static_cast<size_t>(M));
    };

    concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, task_count, conv_worker);

    if (!channels_last_) {
      // Transpose the output from channels last (NHWC) to channels first (NCHW).
      MlasTranspose(
          output_data,
          Ydata,
          static_cast<size_t>(output_image_size),
          static_cast<size_t>(M));
    }

    Xdata += X_offset;
    Ydata += Y_offset;
  }

  return Status::OK();
}