src/GroupwiseConv.h (256 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cassert>
#include <cstdint>
#include <map>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include "./CodeCache.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
/*#define FBGEMM_LOG_CODE 1*/
#define GCONV_INST_AVX2_HEADER \
template <inst_set_t ISET = INST_SET> \
typename std::enable_if<ISET == inst_set_t::avx2, void>::type
#define GCONV_INST_AVX512_AND_VNNI_HEADER \
template <inst_set_t ISET = INST_SET> \
typename std::enable_if< \
ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
void>::type
#define GCONV_INST_DEF_AVX2_HEADER \
template <int SPATIAL_DIM, inst_set_t INST_SET> \
template <inst_set_t ISET> \
typename std::enable_if<ISET == inst_set_t::avx2, void>::type
#define GCONV_INST_DEF_AVX512_AND_VNNI_HEADER \
template <int SPATIAL_DIM, inst_set_t INST_SET> \
template <inst_set_t ISET> \
typename std::enable_if< \
ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
void>::type
namespace fbgemm {
namespace x86 = asmjit::x86;
template <typename>
struct is_requantization : std::false_type {};
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
typename BIAS_TYPE,
typename outT,
typename inT,
typename nextOPType>
struct is_requantization<
ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>>
: std::true_type {};
using jit_conv_kernel_fp = void (*)(
const uint8_t* in_acts,
int8_t* wghts,
int32_t* out_acts,
int32_t a_zero_pt,
int32_t oh_start,
int32_t oh_end,
int32_t ow,
int32_t* row_offset);
using kernel_sig_t = std::tuple<
bool, /* is A zero point 0 */
bool, /* should row offset be calculated */
bool, /* is top edge included */
bool, /* is bottom edge included */
bool, /* is top bottom edge same? */
bool, /* use paddings on bottom side? */
bool, /* use paddings on right side? */
bool, /* accumulate rowoffsets and output instead of overwrite? */
int, /* groups */
int, /* stride */
int, /* number of input channels per group */
int>; /* number of output channels per group */
// Common code in a base class
template <int SPATIAL_DIM, inst_set_t INST_SET>
class GenConvKernelBase {
public:
GenConvKernelBase(
const conv_param_t<SPATIAL_DIM>& conv_param,
std::int32_t a_zero_point,
bool needRowOffset,
bool isTopEdgeIncluded,
bool isBottomEdgeIncluded,
bool isTopBottomEdgeSame,
bool accum) {
assert(fbgemmOptimizedGConv(conv_param));
isAZeroPointZero_ = a_zero_point == 0;
needRowOffset_ = needRowOffset;
isTopEdgeIncluded_ = isTopEdgeIncluded;
isBottomEdgeIncluded_ = isBottomEdgeIncluded;
isTopBottomEdgeSame_ = isTopBottomEdgeSame;
accum_ = accum;
G_ = conv_param.G;
K_per_G_ = conv_param.OC / conv_param.G;
K_ = conv_param.OC;
C_per_G_ = conv_param.IC / conv_param.G;
C_ = conv_param.IC;
// Strides are assumed to be the same in all directions
STRIDE_ = conv_param.stride[0];
R_ = conv_param.K[0];
S_ = conv_param.K[1];
OH_ = conv_param.OUT_DIM[0];
OW_ = conv_param.OUT_DIM[1];
H_PAD_ = conv_param.pad[0];
W_PAD_ = conv_param.pad[1];
use_bottom_padding_ =
!(STRIDE_ > 1 && conv_param.IN_DIM[SPATIAL_DIM - 2] % 2 == 0);
use_right_padding_ =
!(STRIDE_ > 1 && conv_param.IN_DIM[SPATIAL_DIM - 1] % 2 == 0);
}
~GenConvKernelBase() {}
static std::string getCodeLoggingFile(kernel_sig_t kernel_sig) {
std::ostringstream oss;
oss << "conv";
oss << "_G-" << std::get<8>(kernel_sig);
oss << "_stride-" << std::get<9>(kernel_sig);
oss << "_IC_per_G-" << std::get<10>(kernel_sig);
oss << "_OC_per_G-" << std::get<11>(kernel_sig);
oss << "_isZeroPointZero-" << std::get<0>(kernel_sig);
oss << "_rowoffset-" << std::get<1>(kernel_sig);
oss << "_topEdge-" << std::get<2>(kernel_sig);
oss << "_bottomEdge-" << std::get<3>(kernel_sig);
oss << "_isTopBottomSame-" << std::get<4>(kernel_sig);
oss << "_useBottomPadding-" << std::get<5>(kernel_sig);
oss << "_useRightPadding-" << std::get<6>(kernel_sig);
oss << "_accum-" << std::get<7>(kernel_sig);
if (INST_SET == inst_set_t::avx512) {
oss << "_avx512";
} else if (INST_SET == inst_set_t::avx2) {
oss << "_avx2";
} else {
oss << "_unknown";
}
oss << ".txt";
return oss.str();
}
static asmjit::JitRuntime& runtime() {
static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
// depents on other static
// variables. Required to prevent
// initialization order fiasco
return rt;
}
static std::mutex rtMutex_; ///< Control access to runtime;
static CodeCache<
kernel_sig_t,
jit_conv_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
protected:
// current conv parameters
int G_; ///< Number of groups
int K_; ///< Number of output channels
int K_per_G_; ///< Number of output channels per group
int C_; ///< Number of input channels
int STRIDE_; ///< Stride in either direction
int C_per_G_; ///< Number of input channels per group
int R_; ///< Filter/Kernel height
int S_; ///< Filter/Kernel width
int OH_; ///< output height
int OW_; ///< output width
int H_PAD_; ///< Padding for height (top and bottom)
int W_PAD_; ///< Padding for width (left and right)
// Other parameters
bool isAZeroPointZero_;
bool needRowOffset_;
bool isTopEdgeIncluded_;
bool isBottomEdgeIncluded_;
bool isTopBottomEdgeSame_;
bool accum_;
// For 3x3 kernels with pad == 1: If stride is 2 and image height/width are
// even, the right or bottom paddings are not used. This variables is set to
// false if paddings on the left and bottom are not used and kernel generation
// takes care to not generate code with paddings on the right and bottom side.
bool use_bottom_padding_;
bool use_right_padding_;
};
// Generic class
template <int SPATIAL_DIM, inst_set_t INST_SET>
class FBGEMM_API GenConvKernel
: public GenConvKernelBase<SPATIAL_DIM, INST_SET> {
typedef typename simd_info<INST_SET>::vec_reg_t vec_reg_t;
public:
GenConvKernel(
const conv_param_t<SPATIAL_DIM>& conv_param,
std::int32_t a_zero_point,
bool needRowoffset,
bool isTopEdgeIncluded,
bool isBottomEdgeIncluded,
bool isTopBottomEdgeSame,
bool accum)
: GenConvKernelBase<SPATIAL_DIM, INST_SET>(
conv_param,
a_zero_point,
needRowoffset,
isTopEdgeIncluded,
isBottomEdgeIncluded,
isTopBottomEdgeSame,
accum) {
constexpr int SIMD_WIDTH = simd_info<INST_SET>::WIDTH_BYTES;
GTogether_ = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
numOfGroupsTogether(conv_param);
kLoopIters_ = this->K_per_G_ * this->C_per_G_ / SIMD_WIDTH;
// y/zmm0-8 are used for holding weights
zeroPTReg_V_ = vec_reg_t(10);
tmpReg1_V_ = vec_reg_t(11);
stPermReg_V_ = vec_reg_t(12);
actReg_V_ = vec_reg_t(13);
oneReg16Bit_V_ = vec_reg_t(15);
rowOffsetReg_V_ = vec_reg_t(14);
}
jit_conv_kernel_fp getOrCreate();
GCONV_INST_AVX2_HEADER genForLoadingWeights(x86::Emitter* a);
GCONV_INST_AVX512_AND_VNNI_HEADER genForLoadingWeights(x86::Emitter* a);
GCONV_INST_AVX2_HEADER genConstForPermutations(x86::Emitter* a);
GCONV_INST_AVX512_AND_VNNI_HEADER genConstForPermutations(x86::Emitter* a);
GCONV_INST_AVX2_HEADER genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);
GCONV_INST_AVX512_AND_VNNI_HEADER genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);
GCONV_INST_AVX2_HEADER storeResult(x86::Emitter* a);
GCONV_INST_AVX512_AND_VNNI_HEADER storeResult(x86::Emitter* a);
GCONV_INST_AVX2_HEADER storeOffset(x86::Emitter* a);
GCONV_INST_AVX512_AND_VNNI_HEADER storeOffset(x86::Emitter* a);
void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom);
void initResultRegs(x86::Emitter* a);
void genCoreInsts(x86::Emitter* a);
void genForSingleOutput(
x86::Emitter* a,
bool isLeft,
bool isRight,
bool isTop,
bool isBottom);
private:
int GTogether_;
// The number of iterations needed for K dim.
// e.g., C_per_G_ = K_per_G_ = 8, we have to iterate
// twice on K dim because 4 (from K dim) * 8 ( from C dim)
// fill the full avx2 vector width.
int kLoopIters_;
asmjit::FuncDetail func_;
asmjit::FuncFrame frame_;
vec_reg_t zeroPTReg_V_;
vec_reg_t tmpReg1_V_;
vec_reg_t stPermReg_V_;
vec_reg_t actReg_V_;
vec_reg_t resultReg_V_;
vec_reg_t oneReg8Bit_V_;
vec_reg_t oneReg16Bit_V_;
vec_reg_t rowOffsetReg_V_;
// arguments to the function created
x86::Gp in_acts_R_;
x86::Gp wghts_R_;
x86::Gp out_acts_R_;
x86::Gp a_zero_pt_R_;
x86::Gp H_R_;
x86::Gp H_start_R_;
x86::Gp H_end_R_;
x86::Gp W_R_;
x86::Gp row_offset_R_;
// Used registers
x86::Gp loopR1_;
x86::Gp loopR2_;
x86::Gp scratchReg1_;
x86::Gp scratchReg2_;
};
template <int SPATIAL_DIM, inst_set_t INST_SET>
std::mutex GenConvKernelBase<SPATIAL_DIM, INST_SET>::rtMutex_;
template <int SPATIAL_DIM, inst_set_t INST_SET>
CodeCache<kernel_sig_t, jit_conv_kernel_fp>
GenConvKernelBase<SPATIAL_DIM, INST_SET>::codeCache_;
} // namespace fbgemm