// fbgemm_gpu/codegen/embedding_forward_split_template.cu
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
{#
// @lint-ignore LINTIGNORE
// @lint-ignore-every CLANGFORMAT
// clang-format off
// Note: clang-format off doesn't work with this templatized code,
// so we need to keep lint-ignore-every.
// See https://fburl.com/dw9ljh4h
#}
{% set wdesc = "weighted" if weighted else "unweighted" %}
#include "codegen/embedding_forward_template_helpers.cuh"
{% if not dense %}
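// lxu_cache_locations uses -1 to mark rows that are not resident in the LXU cache.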
constexpr int32_t kCacheLocationMissing = -1;
{% endif %}
constexpr size_t kForwardMaxThreads = 512;
using Tensor = at::Tensor;
using namespace fbgemm_gpu;
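{# Generate both the pooled ("bag") and sequence ("nobag") kernel variants; the weighted nobag combination is skipped. #}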
{% for nobag in [True, False] %}
{% if not nobag or not weighted %}
template <
typename emb_t,
typename cache_t,
{% if not dense %}
typename output_t,
{% endif %}
typename index_t
{% if not nobag %}
,size_t kMaxVecsPerThread
{% endif %}
>
__launch_bounds__(kForwardMaxThreads)
__global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel(
const at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
{% if not dense %}
const at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
const at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits>
lxu_cache_weights,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
{% endif %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
{% if not nobag %}
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
{% else %}
int64_t D,
{% endif %}
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
{% if not nobag %}
int64_t pooling_mode,
{% endif %}
{% if weighted %}
at::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits>
indice_weights,
{% endif %}
{% if not dense %}
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations,
at::PackedTensorAccessor32<output_t, 2, at::RestrictPtrTraits>
output // [B][total_D]
{% else %}
at::PackedTensorAccessor32<at::acc_type<cache_t,true>, 2, at::RestrictPtrTraits>
output // [B][total_D]
{% endif %}
) {
{% if not nobag %}
int32_t B = output.size(0);
int32_t T = D_offsets.size(0) - 1;
{% else %}
int32_t T = weights_offsets.size(0);
int32_t B = (offsets.size(0) - 1) / T;
{% endif %}
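// One warp per (table, sample) pair: the launch uses blockDim = (kWarpSize, kForwardMaxThreads / kWarpSize),
// so b_t enumerates all B * T pairs covered by the grid.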
int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
int32_t t = b_t / B;
int32_t b = b_t % B;
if (b_t >= B * T) {
return;
}
int64_t weights_offset = weights_offsets[t];
{% if not nobag %}
int32_t D_start = D_offsets[t];
int32_t D_end = D_offsets[t + 1];
int32_t D = D_end - D_start;
{% endif %}
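// [indices_start, indices_end) is this sample's segment of the flattened indices; L is the bag size.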
index_t indices_start = offsets[t * B + b];
index_t indices_end = offsets[t * B + b + 1];
int32_t L = indices_end - indices_start;
const emb_t* __restrict__ weights;
{% if not dense %}
const auto placement = static_cast<PlacementType>(weights_placements[t]);
if (placement == PlacementType::DEVICE) {
weights = &dev_weights[weights_offset];
} else {
weights = &uvm_weights[weights_offset];
}
{% else %}
weights = &dev_weights[weights_offset];
{% endif %}
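// INT8 (uint8_t) embedding rows store their row-wise quantization parameters inline,
// so the physical row stride is D + kINT8QparamsBytes.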
int32_t D_emb = D;
if (std::is_same<emb_t, uint8_t>::value) {
D_emb += kINT8QparamsBytes;
}
{% if not nobag %}
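// Each thread accumulates into kMaxVecsPerThread Vec4T registers; together a warp covers
// embedding dimensions up to 4 * kWarpSize * kMaxVecsPerThread.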
Vec4T<cache_t> accumulators[kMaxVecsPerThread];
{% endif %}
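// Walk the bag in chunks of kWarpSize indices: each lane loads one index (plus its cache
// location / per-sample weight), and the values are then broadcast lane by lane with
// shfl_sync so the whole warp cooperates on each embedding row.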
for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
{% if not dense %}
int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
{% endif %}
{% if weighted %}
at::acc_type<cache_t, true> idx_weight = l < L ? indice_weights[indices_start + l] : 0;
{% endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int64_t idx_j = shfl_sync(idx, j);
{% if nobag %}
int64_t output_j = indices_start + l_start + j;
{% endif %}
{% if not dense %}
int32_t cache_idx_j = shfl_sync(cache_idx, j);
{% endif %}
{% if weighted %}
at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
{% endif %}
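// WeightRow wraps the embedding row either in the LXU cache or in emb_t storage,
// dequantizing uint8 rows with their row-wise qparams on load.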
{% if not dense %}
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams
{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}
{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
} else {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
}
{% else %}
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
{% endif %}
}
{% else %}
for (int32_t i = 0; i < D; i += 4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
} else {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
}
{% else %}
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
{% endif %}
}
}
{% endif %}
}
}
{% if not nobag %}
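// Write-out for the pooled variant: scale the accumulators for mean pooling and store
// one output row per sample.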
{% if not dense %}
if (!std::is_same<output_t, uint8_t>::value) {
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN && L != 0) {
accumulators[i].mul_(1.0 / L);
}
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
accumulators[i].store(&output[b][D_start + d]);
}
} else {
// Apply per-feature row-wise INT8 quantization to the pooled output.
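// Each thread tracks the min/max of its accumulators; warp_find_qparams reduces them to
// row-wise quantization parameters, each vector is stored with nearest rounding, and the
// qparams themselves are written after the feature's quantized row.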
float thread_local_min = std::numeric_limits<float>::max();
float thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN && L != 0) {
accumulators[i].mul_(1.0 / L);
}
thread_local_max = max(thread_local_max, vec4_max(accumulators[i]));
thread_local_min = min(thread_local_min, vec4_min(accumulators[i]));
}
qparams = warp_find_qparams(thread_local_min, thread_local_max);
int output_D_start = D_start + t * 8;
int output_D_end = output_D_start + D;
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
nearest_rounding_vector<output_t, cache_t>(&output[b][output_D_start + d], accumulators[i], qparams);
}
if (threadIdx.x == 0) {
store_qparams_to_row(&output[b][output_D_end], qparams);
}
}
{% else %}
// no pooled embedding quantization fusion for dense embeddings
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN && L != 0) {
accumulators[i].mul_(1.0 / L);
}
accumulators[i].store(&output[b][D_start + d]);
}
{% endif %}
{% endif %}
}
Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
Tensor dev_weights,
{% if not dense %}
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
{% endif %}
Tensor weights_offsets,
{% if not nobag %}
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
{% else %}
int64_t D,
{% endif %}
Tensor indices,
Tensor offsets,
{% if not nobag %}
int64_t pooling_mode,
{% endif %}
{% if weighted %}
Tensor indice_weights,
{% endif %}
{% if not dense %}
Tensor lxu_cache_locations,
{% endif %}
{% if not dense and not nobag %}
int64_t output_dtype,
{% endif %}
int64_t unused
) {
TENSOR_ON_CUDA_GPU(dev_weights);
{% if not dense %}
TENSOR_ON_CUDA_GPU(uvm_weights);
TENSOR_ON_CUDA_GPU(lxu_cache_weights);
TENSOR_ON_CUDA_GPU(weights_placements);
{% endif %}
TENSOR_ON_CUDA_GPU(weights_offsets);
{% if not nobag %}
TENSOR_ON_CUDA_GPU(D_offsets);
{% endif %}
TENSOR_ON_CUDA_GPU(indices);
TENSOR_ON_CUDA_GPU(offsets);
{% if weighted %}
TENSOR_ON_CUDA_GPU(indice_weights);
{% endif %}
{% if not dense %}
TENSOR_ON_CUDA_GPU(lxu_cache_locations);
{% endif %}
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dev_weights.get_device());
{% if not nobag %}
int32_t T = D_offsets.numel() - 1;
{% else %}
int32_t total_L = indices.numel();
int32_t T = weights_offsets.numel();
{% endif %}
TORCH_CHECK(T > 0);
// offsets = [B x T + 1]
int32_t B = (offsets.size(0) - 1) / T;
TORCH_CHECK(B >= 0);
{% if not nobag %}
TORCH_CHECK(total_D > 0);
TORCH_CHECK(total_D % 4 == 0);
TORCH_CHECK(max_D <= {{ max_embedding_dim }});
{% else %}
TORCH_CHECK(D > 0);
TORCH_CHECK(D % 4 == 0);
{% endif %}
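// Allocate the output: the nobag (sequence) variant emits one float row per index
// ([total_L, D]); the pooled variant emits one row per sample ([B, total_D]), with
// T * kINT8QparamsBytes extra columns when producing fused INT8 output.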
{% if nobag %}
Tensor output = at::empty({total_L, D}, dev_weights.options().dtype(at::kFloat));
{% else %}
Tensor output;
{% if dense %}
if (dev_weights.type().scalarType() == at::kHalf || dev_weights.type().scalarType() == at::kByte) {
output = at::empty({B, total_D}, dev_weights.options().dtype(at::kFloat));
} else {
output = at::empty({B, total_D}, dev_weights.options());
}
{% else %}
SparseType o_dtype = static_cast<SparseType>(output_dtype);
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
if (o_dtype == SparseType::FP32) {
output = at::empty({B, total_D}, dev_weights.options().dtype(at::kFloat));
} else if (o_dtype == SparseType::FP16) {
output = at::empty({B, total_D}, dev_weights.options().dtype(at::kHalf));
} else if (o_dtype == SparseType::BF16) {
output = at::empty({B, total_D}, dev_weights.options().dtype(at::kBFloat16));
} else if (o_dtype == SparseType::INT8) {
output = at::empty({B, int64_t(total_D + T * kINT8QparamsBytes)}, dev_weights.options().dtype(at::kByte));
}
{% endif %}
{% endif %}
if (B == 0) {
return output;
}
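{# Dispatch on the runtime scalar types and launch the matching kernel instantiation;
   the pooled variant picks the smallest kMaxVecsPerThread specialization covering max_D. #}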
{% if not dense %}
DISPATCH_EMB_CACHE_OUTPUT_TYPES(
{% else %}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
{% endif %}
dev_weights.type(),
{% if not dense %}
lxu_cache_weights.type(),
output.type(),
{% endif %}
"batched_embedding{{ "_nobag" if nobag else "" }}_forward_kernel_2", [&] {
{% if not nobag %}
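{# Emit one "if (max_D <= 128 * kMaxVecsPerThread)" specialization per supported bucket. #}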
{% for kMaxVecsPerThread in range(1, max_embedding_dim // 128 + 1) %}
if (max_D <= {{ 128 * kMaxVecsPerThread }}) {
{% if not dense %}
split_embedding_codegen_forward_{{ wdesc }}_kernel<emb_t, cache_t, output_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% else %}
dense_embedding_codegen_forward_{{ wdesc }}_kernel<scalar_t, scalar_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
pooling_mode,
{% if weighted %}
indice_weights.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(),
{% endif %}
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
return;
}
{% endfor %}
{% else %}
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, int64_t><<<
{% else %}
dense_embedding_nobag_codegen_forward_unweighted_kernel<scalar_t, scalar_t, int64_t><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
return;
{% endif %}
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
{% endif %}
{% endfor %}
// clang-format on