void GenConvKernel::genForSingleOutput()

in src/GroupwiseConv.cc [321:408]


void GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleOutput(
    x86::Emitter* a,
    bool isLeft,
    bool isRight,
    bool isTop,
    bool isBottom) {
  // Emits, through `a`, the instruction sequence for one output position:
  // walks every (r, s) filter tap, stores the result (and the row offset
  // when needRowOffset_ is set), then leaves the input/output pointer
  // registers positioned for the next output position.
  //
  // isLeft / isRight / isTop / isBottom tell which image borders this output
  // position touches. Taps that fall outside the image on a flagged side are
  // generated with the "padding" flag set, and only when the activation zero
  // point is non-zero (isAZeroPointZero_ is false).

  // True when input row `hIn` is inside the image, given the flagged
  // top/bottom borders.
  const auto rowInsideImage = [&](int hIn) -> bool {
    if (isTop && hIn < 0) {
      return false;
    }
    if (!isBottom) {
      return true;
    }
    return isTop ? hIn < (this->R_ - 2 * this->H_PAD_)
                 : hIn < (this->R_ - this->H_PAD_);
  };

  // True when input column `wIn` is inside the image, given the flagged
  // left/right borders.
  const auto colInsideImage = [&](int wIn) -> bool {
    if (isLeft && wIn < 0) {
      return false;
    }
    if (!isRight) {
      return true;
    }
    return isLeft ? wIn < (this->S_ - 2 * this->W_PAD_)
                  : wIn < (this->S_ - this->W_PAD_);
  };

  // start from freshly initialized result registers
  initResultRegs(a);

  // zero the row offset register when row offsets are requested
  if (this->needRowOffset_) {
    a->vpxor(
        rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm());
  }

  // Counts how many times the emitted code moves the input pointer down a
  // row, so the same distance can be rewound afterwards.
  int rowsStepped = 0;
  for (int filterRow = 0; filterRow < this->R_; ++filterRow) {
    const int hIn = isTop ? filterRow - this->H_PAD_ : filterRow;
    const bool rowInside = rowInsideImage(hIn);

    for (int filterCol = 0; filterCol < this->S_; ++filterCol) {
      const int wIn = isLeft ? filterCol - this->W_PAD_ : filterCol;
      const bool isPaddingTap = !(rowInside && colInsideImage(wIn));

      // Padding taps are skipped entirely when the activation zero point is
      // zero; every other tap is generated.
      if (!isPaddingTap || !this->isAZeroPointZero_) {
        genForSingleFilterPoint(a, filterRow, filterCol, wIn, isPaddingTap);
      }
    }

    if (!rowInside) {
      continue;
    }
    // move the input pointer down by one row (W_R_ * C_ bytes)
    a->imul(
        scratchReg2_,
        W_R_,
        static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    a->add(in_acts_R_, scratchReg2_);
    ++rowsStepped;
  }

  storeResult(a);

  // store the row offset and step its pointer past GTogether_ entries
  if (this->needRowOffset_) {
    storeOffset(a);
    a->add(
        row_offset_R_, static_cast<asmjit::Imm>(GTogether_ * sizeof(int32_t)));
  }

  // undo all the per-row input pointer steps emitted above
  a->imul(
      scratchReg2_,
      W_R_,
      static_cast<asmjit::Imm>(rowsStepped * this->C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);

  // step the output pointer past the K_ values just produced
  a->add(out_acts_R_, static_cast<asmjit::Imm>(this->K_ * sizeof(int32_t)));

  // step the input pointer horizontally to the next output position
  if (isLeft) {
    // At the left border the step is shortened by the left padding; when
    // that leaves nothing to add, no instruction is emitted.
    const auto borderStep = this->STRIDE_ - this->W_PAD_;
    if (borderStep) {
      a->add(
          in_acts_R_,
          static_cast<asmjit::Imm>(borderStep * this->C_ * sizeof(uint8_t)));
    }
  } else {
    a->add(
        in_acts_R_,
        static_cast<asmjit::Imm>(this->STRIDE_ * this->C_ * sizeof(uint8_t)));
  }
}