in maga_transformer/cpp/cutlass/cutlass_kernels/cutlass_preprocessors.cc [223:355]
// Transposes a 2-D or 3-D quantized weight tensor stored as raw bytes, writing the
// result to `transposed_quantized_tensor`. NOTE(review): `quant_type` is a compile-time
// constant (see static_assert below) — this is a function template whose
// `template <QuantType quant_type>` header sits above this chunk; confirm in full file.
// For a 3-D shape, dim 0 is treated as an expert count and each [rows x cols] matrix
// is transposed independently.
void subbyte_transpose_impl(
int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector<size_t>& shape)
{
const int bits_per_elt = get_bits_in_quant_type(quant_type);
RTP_LLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
// 2-D shape is [rows, cols]; 3-D shape is [experts, rows, cols].
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
// Row pitch in bytes for the source matrix and for the transposed matrix.
// For packed int4, two elements share one byte, so bytes = elts * bits / 8.
const size_t col_bytes = num_cols * bits_per_elt / 8;
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
const uint8_t* input_byte_ptr = reinterpret_cast<const uint8_t*>(quantized_tensor);
uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
// 1 element per byte for int8; 2 (nibbles) per byte for packed int4.
static constexpr int ELTS_PER_BYTE = quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
// L1 cache tile: M_TILE_L1 x M_TILE_L1 *elements*, i.e. M_TILE_L1 x N_TILE_L1 bytes.
static constexpr int M_TILE_L1 = 64;
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
// We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples
// of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it
// allows GCC to emit vector instructions.
RTP_LLM_CHECK_WITH_INFO(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
rtp_llm::fmtstr("Number of bytes for rows and cols must be a multiple of %d. However, num_rows_bytes = %ld and "
"num_col_bytes = %ld.",
VECTOR_WIDTH, col_bytes_trans, col_bytes));
for (size_t expert = 0; expert < num_experts; ++expert)
{
// Byte offset of this expert's matrix; same offset is used for input and output
// since a matrix and its transpose occupy the same number of bytes.
const size_t matrix_offset = expert * num_rows * col_bytes;
for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1)
{
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1)
{
const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
// Stage 1: load the source tile into cache_buf, VECTOR_WIDTH bytes at a
// time (fixed-trip inner loop so the compiler can vectorize it).
for (int ii = 0; ii < M_TILE_L1; ++ii)
{
const int row = row_tile_start + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{
const int col = col_tile_start_byte + jj;
const size_t logical_src_offset = matrix_offset + row * col_bytes + col;
// Partial (edge) tiles: positions beyond the limits are simply not
// loaded, leaving stale bytes in cache_buf; the bounds-checked
// write-back below never emits them.
if (row < row_limit && col < col_limit)
{
for (int v = 0; v < VECTOR_WIDTH; ++v)
{
cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
}
}
}
}
// Stage 2: transpose the tile in place inside cache_buf.
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
{
// One element per byte, and the tile is square in bytes
// (M_TILE_L1 == N_TILE_L1), so a plain upper-triangle swap works.
for (int ii = 0; ii < M_TILE_L1; ++ii)
{
for (int jj = ii + 1; jj < N_TILE_L1; ++jj)
{
std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
}
}
}
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
{
for (int ii = 0; ii < M_TILE_L1; ++ii)
{
// Using M_TILE_L1 here is deliberate since we assume that the cache tile
// is square in the number of elements (not necessarily the number of bytes).
for (int jj = ii + 1; jj < M_TILE_L1; ++jj)
{
// Locate the two nibbles to exchange: element (ii, jj) and its
// mirror (jj, ii). bit_offset selects the low (0) or high (1)
// nibble of the byte.
const int ii_byte = ii / ELTS_PER_BYTE;
const int ii_bit_offset = ii % ELTS_PER_BYTE;
const int jj_byte = jj / ELTS_PER_BYTE;
const int jj_bit_offset = jj % ELTS_PER_BYTE;
// Extract both nibbles, clear their slots, then write each
// nibble into the other's slot. Order matters: both extracts
// happen before either byte is modified.
uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
}
}
}
else
{
RTP_LLM_FAIL("Unsupported quantization type.");
}
// Stage 3: write the transposed tile back. In the transposed matrix the
// tile's row coordinate comes from the source column (converted from bytes
// to elements) and its byte-column coordinate from the source row.
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
for (int ii = 0; ii < M_TILE_L1; ++ii)
{
const int row = row_tile_start_trans + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{
const int col = col_tile_start_byte_trans + jj;
const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col;
if (row < row_limit_trans && col < col_limit_trans)
{
for (int v = 0; v < VECTOR_WIDTH; ++v)
{
output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
}
}
}
}
}
}
}
}