in src/GroupwiseConv.cc [321:408]
void GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleOutput(
    x86::Emitter* a,
    bool isLeft,
    bool isRight,
    bool isTop,
    bool isBottom) {
  // Emits, through `a`, the code for one output position: reset the result
  // registers, visit every (r, s) filter tap of the R_ x S_ window, store the
  // results, and finally step the input/output pointer registers.
  //
  // isLeft / isRight / isTop / isBottom say which image borders this output
  // position touches. Taps that fall beyond a flagged border are handled as
  // padding. Opposite flags may both be set (see the combined cases below).

  initResultRegs(a);

  // XOR-ing the register with itself clears the row-offset accumulator.
  if (this->needRowOffset_) {
    a->vpxor(
        rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm());
  }

  // Number of "advance by one row" sequences emitted in the loop; the rewind
  // after the loop has to undo exactly that many.
  int rowsStepped = 0;

  for (int r = 0; r < this->R_; ++r) {
    // When the top border is flagged the tap row coordinate is shifted up by
    // H_PAD_, so the first rows can be negative (i.e. above the image).
    const int rowPos = isTop ? -this->H_PAD_ + r : r;

    // A row is inside the image unless a flagged border excludes it.
    bool rowInImage = true;
    if (isTop && isBottom) {
      rowInImage = rowPos >= 0 && rowPos < (this->R_ - 2 * this->H_PAD_);
    } else if (isTop) {
      rowInImage = rowPos >= 0;
    } else if (isBottom) {
      rowInImage = rowPos < (this->R_ - this->H_PAD_);
    }

    for (int s = 0; s < this->S_; ++s) {
      // Same shift as for rows, applied to the column coordinate.
      const int w_in = isLeft ? -this->W_PAD_ + s : s;

      bool colInImage = true;
      if (isLeft && isRight) {
        colInImage = w_in >= 0 && w_in < (this->S_ - 2 * this->W_PAD_);
      } else if (isLeft) {
        colInImage = w_in >= 0;
      } else if (isRight) {
        colInImage = w_in < (this->S_ - this->W_PAD_);
      }

      // In-image taps are always generated (last argument false). Padding
      // taps are generated with the last argument true, and only when the
      // activation zero point is non-zero; otherwise they are skipped.
      // NOTE(review): the last argument presumably selects a zero/padding
      // input inside genForSingleFilterPoint - confirm against that helper.
      const bool isPaddingTap = !(rowInImage && colInImage);
      if (!isPaddingTap || !this->isAZeroPointZero_) {
        genForSingleFilterPoint(a, r, s, w_in, isPaddingTap);
      }
    }

    if (rowInImage) {
      // Move the input pointer forward by one row:
      // scratchReg2_ = W_R_ * (C_ * sizeof(uint8_t)); in_acts_R_ += scratchReg2_.
      a->imul(
          scratchReg2_,
          W_R_,
          static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
      a->add(in_acts_R_, scratchReg2_);
      ++rowsStepped;
    }
  }

  storeResult(a);

  // Store the row offset and bump its pointer by GTogether_ int32 entries.
  if (this->needRowOffset_) {
    storeOffset(a);
    a->add(
        row_offset_R_, static_cast<asmjit::Imm>(GTogether_ * sizeof(int32_t)));
  }

  // Undo every row advance emitted above so the input pointer is back on the
  // row it started from.
  a->imul(
      scratchReg2_,
      W_R_,
      static_cast<asmjit::Imm>(rowsStepped * this->C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);

  // Step the output pointer past the K_ int32 values of this position.
  a->add(out_acts_R_, static_cast<asmjit::Imm>(this->K_ * sizeof(int32_t)));

  // Step the input pointer horizontally to the next output position.
  if (isLeft) {
    // At the left border the step is shortened by W_PAD_ columns; no
    // instruction is emitted when that leaves a zero displacement.
    if ((this->STRIDE_ - this->W_PAD_) != 0) {
      a->add(
          in_acts_R_,
          static_cast<asmjit::Imm>(
              (this->STRIDE_ - this->W_PAD_) * this->C_ * sizeof(uint8_t)));
    }
  } else {
    a->add(
        in_acts_R_,
        static_cast<asmjit::Imm>(this->STRIDE_ * this->C_ * sizeof(uint8_t)));
  }
}