bool isValidBlockingFactor(BlockingFactors* param)

in include/fbgemm/Utils.h, lines 288–338


/**
 * Validates a set of blocking factors for accumulation type accT against the
 * instruction set reported by fbgemmInstructionSet().
 *
 * Checks, in order:
 *  - param is non-null and MR, NR, NR_MIN, MCB, NCB are strictly positive
 *    (they are used as divisors / block sizes below);
 *  - ROW_INTERLEAVE is 4 for int32_t accumulation and 2 for int16_t;
 *  - on Zmm / Ymm instruction sets, NR_MIN is 16 / 8 for int32_t and
 *    32 / 16 for int16_t, and NR is a multiple of NR_MIN;
 *  - MCB is a multiple of MR and NCB is a multiple of NR;
 *  - the register tile MR * (NR / NR_MIN) -- plus one extra row of
 *    NR / NR_MIN for int16_t on Zmm -- does not exceed 28 on Zmm or 12 on
 *    Ymm. (These budgets presumably leave headroom in the 32 zmm / 16 ymm
 *    register files for other operands -- TODO confirm against the kernels.)
 *
 * For instruction sets that are neither Zmm nor Ymm, only the ISA-independent
 * checks (positivity, ROW_INTERLEAVE, MCB/NCB divisibility) are applied.
 *
 * @param param blocking factors to validate; nullptr is rejected.
 * @return true if every check passes, false otherwise.
 */
bool isValidBlockingFactor(BlockingFactors* param) {
  constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
  constexpr bool is_16bit = std::is_same<accT, int16_t>::value;
  static const auto iset = fbgemmInstructionSet();

  if (param == nullptr)
    return false;

  // MR, NR and NR_MIN are used as divisors below; a zero value would make the
  // `%` / `/` expressions undefined behavior (in particular on the Ymm path
  // for accT other than int32_t/int16_t, where NR_MIN is otherwise never
  // validated before being divided by). MCB and NCB are block sizes, and a
  // zero block would trivially pass the divisibility checks, so reject
  // non-positive values for all of them up front.
  if (param->MR <= 0 || param->NR <= 0 || param->NR_MIN <= 0 ||
      param->MCB <= 0 || param->NCB <= 0)
    return false;

  if (is_32bit) {
    if (param->ROW_INTERLEAVE != 4)
      return false;

    if (isZmm(iset)) {
      if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
        return false;
    } else if (isYmm(iset)) {
      if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
        return false;
    }
  } else if (is_16bit) {
    if (param->ROW_INTERLEAVE != 2)
      return false;

    if (isZmm(iset)) {
      if (param->NR_MIN != 32 || param->NR % param->NR_MIN)
        return false;
    } else if (isYmm(iset)) {
      if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
        return false;
    }
  }

  // Cache blocks must be whole multiples of the register blocks.
  if (param->MCB % param->MR)
    return false;
  if (param->NCB % param->NR)
    return false;

  if (isZmm(iset)) {
    if (is_32bit) {
      // Zmm register usage for C
      if (param->MR * (param->NR / param->NR_MIN) > 28)
        return false;
    } else if (is_16bit) {
      // Zmm register usage for C + one row for loading B
      if ((param->MR * (param->NR / param->NR_MIN) +
           (param->NR / param->NR_MIN)) > 28)
        return false;
    }
  } else if (isYmm(iset)) {
    // Ymm register usage for C
    if (param->MR * (param->NR / param->NR_MIN) > 12)
      return false;
  }
  return true;
}