// maga_transformer/cpp/kernels/rotary_position_embedding.h
#pragma once
#include "decoder_masked_multihead_attention_utils.h"
#include "maga_transformer/cpp/utils/RopeConfig.h"
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#if USING_CUDA
#include "maga_transformer/cpp/cuda/cuda_utils.h"
#endif
#if USING_ROCM
#include "maga_transformer/cpp/rocm/hip_utils.h"
#endif
#include <type_traits>
namespace rtp_llm {
// Maps a (scalar, vector-carrier) pair to the number of scalar lanes the
// carrier packs, e.g. one uint32_t holds two halves.
template<typename scalar_t, typename vector_t>
struct vector_size {};
template<>
struct vector_size<half, uint32_t> {
static constexpr int size = 2;
};
template<>
struct vector_size<float, float> {
static constexpr int size = 1;
};
template<>
struct vector_size<float, float2> {
static constexpr int size = 2;
};
template<>
struct vector_size<float, float4> {
static constexpr int size = 4;
};
template<>
struct vector_size<uint16_t, uint32_t> {
static constexpr int size = 2;
};
template<>
struct vector_size<uint16_t, uint2> {
static constexpr int size = 4;
};
template<>
struct vector_size<uint16_t, uint4> {
static constexpr int size = 8;
};
#ifdef ENABLE_BF16
template<>
struct vector_size<__nv_bfloat16, __nv_bfloat162> {
static constexpr int size = 2;
};
template<>
struct vector_size<__nv_bfloat16, bf16_4_t> {
static constexpr int size = 4;
};
template<>
struct vector_size<__nv_bfloat16, bf16_8_t> {
static constexpr int size = 8;
};
#endif
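// A minimal usage sketch of the trait above (these asserts are illustrative,
// not part of the original header):
static_assert(vector_size<float, float4>::size == 4, "float4 packs 4 floats");
static_assert(vector_size<uint16_t, uint4>::size == 8, "uint4 packs 8 fp16 lanes");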
// True when vector_t exposes scalar_t members (.x, .y, ...) that can be
// assigned by name; false for packed carriers that need a union to split lanes.
template<typename scalar_t, typename vector_t>
struct is_alignment {};
template<>
struct is_alignment<float, float> {
static constexpr bool value = true;
};
template<>
struct is_alignment<float, float2> {
static constexpr bool value = true;
};
template<>
struct is_alignment<float, float4> {
static constexpr bool value = true;
};
template<>
struct is_alignment<uint16_t, uint32_t> {
static constexpr bool value = false;
};
template<>
struct is_alignment<uint16_t, uint2> {
static constexpr bool value = false;
};
template<>
struct is_alignment<uint16_t, uint4> {
static constexpr bool value = false;
};
template<>
struct is_alignment<half, uint32_t> {
static constexpr bool value = false;
};
#ifdef ENABLE_BF16
template<>
struct is_alignment<__nv_bfloat16, __nv_bfloat162> {
static constexpr bool value = true;
};
template<>
struct is_alignment<__nv_bfloat16, bf16_4_t> {
static constexpr bool value = false;
};
template<>
struct is_alignment<__nv_bfloat16, bf16_8_t> {
static constexpr bool value = false;
};
#endif
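// Illustrative checks of the distinction (ours, not original): float2 exposes
// named float members, while uint32_t merely carries two packed halves.
static_assert(is_alignment<float, float2>::value, "float2 lanes are assignable by name");
static_assert(!is_alignment<half, uint32_t>::value, "uint32_t needs a union to split halves");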
// Packs scalars into, and unpacks them out of, a vector carrier. Aligned
// carriers use their named members directly; packed carriers go through a
// union that reinterprets the bits.
template<typename scalar_t, typename vector_t>
struct assign {
// Size-1 case: the scalar is its own carrier, nothing to pack.
static __device__ __inline__ void read(vector_t& vec, scalar_t& x){};
static __device__ __inline__ void read2(vector_t& vec, scalar_t& x, scalar_t& y) {
if constexpr (is_alignment<scalar_t, vector_t>::value) {
vec.x = x;
vec.y = y;
} else {
union {
vector_t r;
scalar_t t[2];
} result;
result.t[0] = x;
result.t[1] = y;
vec = result.r;
}
};
static __device__ __inline__ void read4(vector_t& vec, scalar_t& x, scalar_t& y, scalar_t& z, scalar_t& w) {
if constexpr (is_alignment<scalar_t, vector_t>::value) {
vec.x = x;
vec.y = y;
vec.z = z;
vec.w = w;
} else {
union {
vector_t r;
scalar_t t[4];
} result;
result.t[0] = x;
result.t[1] = y;
result.t[2] = z;
result.t[3] = w;
vec = result.r;
}
};
static __device__ __inline__ void read8(vector_t& vec,
scalar_t& x,
scalar_t& y,
scalar_t& z,
scalar_t& w,
scalar_t& a,
scalar_t& b,
scalar_t& c,
scalar_t& d) {
if constexpr (is_alignment<scalar_t, vector_t>::value) {
return;
} else {
union {
vector_t r;
scalar_t t[8];
} result;
result.t[0] = x;
result.t[1] = y;
result.t[2] = z;
result.t[3] = w;
result.t[4] = a;
result.t[5] = b;
result.t[6] = c;
result.t[7] = d;
vec = result.r;
}
};
// Size-1 case: nothing to unpack.
static __device__ __inline__ void write(vector_t& vec, scalar_t& x) {};
static __device__ __inline__ void write2(vector_t& vec, scalar_t& x, scalar_t& y) {
if constexpr (is_alignment<scalar_t, vector_t>::value) {
x = vec.x;
y = vec.y;
} else {
union {
vector_t r;
scalar_t t[2];
} result;
result.r = vec;
x = result.t[0];
y = result.t[1];
}
};
static __device__ __inline__ void write4(vector_t& vec, scalar_t& x, scalar_t& y, scalar_t& z, scalar_t& w) {
if constexpr (is_alignment<scalar_t, vector_t>::value) {
x = vec.x;
y = vec.y;
z = vec.z;
w = vec.w;
} else {
union {
vector_t r;
scalar_t t[4];
} result;
result.r = vec;
x = result.t[0];
y = result.t[1];
z = result.t[2];
w = result.t[3];
}
};
static __device__ __inline__ void write8(vector_t& vec,
scalar_t& x,
scalar_t& y,
scalar_t& z,
scalar_t& w,
scalar_t& a,
scalar_t& b,
scalar_t& c,
scalar_t& d) {
if constexpr (is_alignment<scalar_t, vector_t>::value) {
return;
} else {
union {
vector_t r;
scalar_t t[8];
} result;
result.r = vec;
x = result.t[0];
y = result.t[1];
z = result.t[2];
w = result.t[3];
a = result.t[4];
b = result.t[5];
c = result.t[6];
d = result.t[7];
}
};
};
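// A small device-side sketch (the helper name is ours) showing the intended
// round trip: read2 packs two scalars into the carrier, write2 unpacks them.
__device__ __inline__ void assign_round_trip_example(half a, half b) {
uint32_t packed;
assign<half, uint32_t>::read2(packed, a, b);   // pack a, b into one register
assign<half, uint32_t>::write2(packed, a, b);  // unpack back into a, b
}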
// Reads one thread's lanes from smem in the half-rotated layout: lane j is
// paired with lane j + dim, where dim is half the rotary width.
template<typename vector_t, typename scalar_t>
__device__ __inline__ void RotaryHalfRead(vector_t& vec, scalar_t* smem, const int idx, int dim) {
constexpr int size = vector_size<scalar_t, vector_t>::size;
static_assert(size == 1 || size == 2 || size == 4 || size == 8, "vector size is not valid");
if constexpr (size == 2) {
assign<scalar_t, vector_t>::read2(vec, smem[idx], smem[idx + dim]);
} else if constexpr (size == 4) {
assign<scalar_t, vector_t>::read4(
vec, smem[idx * 2], smem[idx * 2 + dim], smem[idx * 2 + 1], smem[idx * 2 + 1 + dim]);
} else if constexpr (size == 8) {
assign<scalar_t, vector_t>::read8(vec,
smem[idx * 4],
smem[idx * 4 + dim],
smem[idx * 4 + 1],
smem[idx * 4 + 1 + dim],
smem[idx * 4 + 2],
smem[idx * 4 + 2 + dim],
smem[idx * 4 + 3],
smem[idx * 4 + 3 + dim]);
}
}
// Writes one thread's lanes back to smem in the same half-rotated pairing.
template<typename vector_t, typename scalar_t>
__device__ __inline__ void RotaryHalfWrite(vector_t& vec, scalar_t* smem, const int idx, int dim) {
constexpr int size = vector_size<scalar_t, vector_t>::size;
static_assert(size == 1 || size == 2 || size == 4 || size == 8, "vector size is not valid");
if constexpr (size == 2) {
assign<scalar_t, vector_t>::write2(vec, smem[idx], smem[idx + dim]);
} else if constexpr (size == 4) {
assign<scalar_t, vector_t>::write4(
vec, smem[idx * 2], smem[idx * 2 + dim], smem[idx * 2 + 1], smem[idx * 2 + 1 + dim]);
} else if constexpr (size == 8) {
assign<scalar_t, vector_t>::write8(vec,
smem[idx * 4],
smem[idx * 4 + dim],
smem[idx * 4 + 1],
smem[idx * 4 + 1 + dim],
smem[idx * 4 + 2],
smem[idx * 4 + 2 + dim],
smem[idx * 4 + 3],
smem[idx * 4 + 3 + dim]);
}
}
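// Example pairing for dim = 4 (rotary width 8): a 4-lane carrier at idx = 0
// touches smem[0], smem[4], smem[1], smem[5], so lane j meets lane j + dim.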
////////////////////////////////////////////////////////////////////////////////////////////////////
// Returns the rotation angle for even channel index `zid` at position
// `t_step`: t_step / base^(zid / dim). Despite the name, the position is
// already folded in, so this is an angle rather than a bare inverse frequency.
__device__ __inline__ float rope_inv_freq(const int zid,
const int rot_embed_dim,
const float t_step,
const float base)
{
return (t_step / pow(base, zid / (float)rot_embed_dim));
}
// Computes (cos, sin) rotation coefficients for channel pair `zid`, letting
// the RopeInit policy rewrite the angle (linear scale, YaRN, Llama3, ...) and
// rescale the result (YaRN mscale).
template<typename RopeInit>
__device__ __inline__ float2 rotary_embedding_coefficient(
const int zid,
const int rot_embed_dim,
const float t_step,
const float base,
const RopeInit &rope_init)
{
float inv_freq = rope_inv_freq(zid, rot_embed_dim, t_step, base);
inv_freq = rope_init(inv_freq, zid);
float sin_cos_scale = rope_init.sin_cos_scale();
// Note: the ROCm path uses double temporaries for sincos; the results are
// narrowed back to float on return.
#if USING_CUDA
float sin_i, cos_i;
#endif
#if USING_ROCM
double sin_i, cos_i;
#endif
sincos(inv_freq, &sin_i, &cos_i);
return {sin_cos_scale * cos_i, sin_cos_scale * sin_i};
}
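// Worked example: with base = 10000 and rot_embed_dim = 128, channel pair
// zid = 0 rotates by the full position (angle = t_step), while zid = 126
// rotates by t_step / 10000^(126/128), i.e. barely at all.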
// Vanilla RoPE: the angle passes through unchanged.
struct DefaultRope {
__device__ __inline__ float operator()(float inv_freq, int zid) const {
return inv_freq;
}
__device__ __inline__ float sin_cos_scale() const {
return 1.0;
}
};
// Position interpolation: every angle is shrunk by `scale`.
struct LinearScaleRope {
float scale = 1.0;
__device__ __inline__ float operator()(float inv_freq, int zid) const {
return inv_freq / scale;
}
__device__ __inline__ float sin_cos_scale() const {
return 1.0;
}
};
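// e.g. LinearScaleRope{4.f} divides every angle by 4, the standard position
// interpolation that stretches the usable context roughly 4x.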
// YaRN (https://arxiv.org/abs/2309.00071): blends interpolated and
// extrapolated frequencies with a per-channel ramp and rescales sin/cos by mscale.
struct YarnRope {
int dim;
int base;
int max_pos;
float beta_slow;
float beta_fast;
float scaling_factor;
float extrapolation_factor;
float mscale;
// Inverts the wavelength formula: returns the (fractional) channel pair at
// which `num_rotations` full rotations fit into max_position_embeddings.
// The betas passed in may be fractional, so take them as float, not int.
static __device__ __inline__ float find_correction_dim(const float num_rotations, const int dim, const int base,
const int max_position_embeddings=2048) {
float pi = 3.141592654f;
float t0 = dim * logf((float)max_position_embeddings / (num_rotations * 2 * pi));
float t1 = 2 * logf((float)base);
return (t0 / t1);
}
static __device__ __inline__ float2 find_correction_range(float low_rot, float high_rot, int dim, int base,
int max_position_embeddings)
{
float2 low_high;
int low = floor(find_correction_dim(low_rot, dim, base, max_position_embeddings));
int high = ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings));
low_high.x = max(low, 0);
low_high.y = min(high, dim-1);
return low_high;
}
static __device__ __inline__ float linear_ramp_mask(float min_, float max_, int tidx) {
if (min_ == max_) {
max_ += 0.001f;  // avoid division by zero on a degenerate range
}
// tidx is the even channel index, so tidx / 2 is the pair index being ramped
float linear = (tidx / 2 - min_) / (max_ - min_);
return std::min(1.f, std::max(0.f, linear));
}
__device__ __inline__ float operator()(float inv_freq, int zid) const {
float2 low_high = find_correction_range(beta_fast, beta_slow, dim, base, max_pos);
const float inv_freq_e = inv_freq;
const float inv_freq_i = inv_freq_e / scaling_factor;
const float mask = (1 - linear_ramp_mask(low_high.x, low_high.y, zid)) * extrapolation_factor;
return inv_freq_i * (1 - mask) + inv_freq_e * mask;
}
__device__ __inline__ float sin_cos_scale() const {
return mscale;
}
};
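// Worked example with common settings (our numbers, for illustration):
// dim = 128, base = 10000, max_pos = 2048, beta_fast = 32, beta_slow = 1 give
// a correction range of roughly pair indices 16..41; pairs below it keep
// their frequency (extrapolation), pairs above it are divided by
// scaling_factor (interpolation), and the band in between blends linearly.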
// Llama 3 rope scaling: short wavelengths pass through, long wavelengths are
// slowed by `factor`, and a middle band is smoothly interpolated.
struct Llama3Rope {
float low_freq_factor;
float high_freq_factor;
float factor;
int old_context_len;
__device__ __inline__ float operator()(float inv_freq, int zid) const {
const float pi = 3.141592654f;
const float wavelen = 2 * pi / inv_freq;
const float low_freq_wavelen = old_context_len / low_freq_factor;
const float high_freq_wavelen = old_context_len / high_freq_factor;
if (wavelen < high_freq_wavelen) {
return inv_freq;
} else if (wavelen > low_freq_wavelen) {
return inv_freq / factor;
} else {
const float smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
return (1 - smooth) * inv_freq / factor + smooth * inv_freq;
}
}
__device__ __inline__ float sin_cos_scale() const {
return 1.0;
}
};
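// Worked example (our numbers): with old_context_len = 8192, factor = 8,
// low_freq_factor = 1 and high_freq_factor = 4, wavelengths below
// 8192 / 4 = 2048 pass through, wavelengths above 8192 / 1 = 8192 are slowed
// 8x, and the band in between is smoothly interpolated.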
// Core rotation: complex multiply (v.x + i*v.y) * (cos + i*sin).
__device__ __inline__ float2 rotary_embedding_transform(const float2 v, const float2 coef) {
float2 rot_v;
rot_v.x = coef.x * v.x - coef.y * v.y;
rot_v.y = coef.x * v.y + coef.y * v.x;
return rot_v;
}
__device__ __inline__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) {
float2 fv = half2_to_float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return float2_to_half2(rot_fv);
}
/**
 * Rotary position embedding
 * Reference: https://arxiv.org/abs/2309.00071
 *
 * Decoder:
 * F(x, pos) = dot(R(pos, ReIndex(x)), Re(x)) + dot(R'(pos, ImIndex(x)), Im(x))
 * R(pos, i)  = [cos(pos * Base(i)), -sin(pos * Base(i))]
 * R'(pos, i) = [sin(pos * Base(i)),  cos(pos * Base(i))]
 * Base(i) = 1 / (base^(2*i / dim))
 */
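// A standalone reference of the rotation above (a minimal sketch; the name is
// ours and nothing in the kernels calls it): rotates the pair (x, y) at even
// channel index zid = 2*i by pos * Base(i).
__device__ __inline__ void rope_reference_pair(float& x, float& y, int zid, int dim, float pos, float base) {
const float angle = pos / powf(base, zid / (float)dim);  // pos * Base(i)
const float c = cosf(angle), s = sinf(angle);
const float x_rot = c * x - s * y;
y = c * y + s * x;
x = x_rot;
}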
#ifdef ENABLE_BF16
__device__ __inline__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) {
float2 fv = bf1622float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(float2& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (2 * tid >= rot_embed_dim) {
return;
}
float2 coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base, rope_init);
q = rotary_embedding_transform(q, coef);
}
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(float4& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
float2 coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base, rope_init);
float2 coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base, rope_init);
q_.x = rotary_embedding_transform(q_.x, coef0);
q_.y = rotary_embedding_transform(q_.y, coef1);
}
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(uint32_t& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (2 * tid >= rot_embed_dim) {
return;
}
float2 coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base, rope_init);
q = rotary_embedding_transform(q, coef);
}
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(uint2& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (4 * tid >= rot_embed_dim) {
return;
}
float2 coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base, rope_init);
float2 coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base, rope_init);
q.x = rotary_embedding_transform(q.x, coef0);
q.y = rotary_embedding_transform(q.y, coef1);
}
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(uint4& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (8 * tid >= rot_embed_dim) {
return;
}
float2 coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base, rope_init);
float2 coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base, rope_init);
float2 coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base, rope_init);
float2 coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base, rope_init);
q.x = rotary_embedding_transform(q.x, coef0);
q.y = rotary_embedding_transform(q.y, coef1);
q.z = rotary_embedding_transform(q.z, coef2);
q.w = rotary_embedding_transform(q.w, coef3);
}
#ifdef ENABLE_BF16
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(__nv_bfloat162& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (2 * tid >= rot_embed_dim) {
return;
}
float2 coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base, rope_init);
q = rotary_embedding_transform(q, coef);
}
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(bf16_4_t& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init) {
if (4 * tid >= rot_embed_dim) {
return;
}
float2 coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base, rope_init);
float2 coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base, rope_init);
q.x = rotary_embedding_transform(q.x, coef0);
q.y = rotary_embedding_transform(q.y, coef1);
}
template<typename RopeInit>
__device__ __inline__ void apply_rotary_embedding(bf16_8_t& q,
int tid,
int rot_embed_dim,
int t_step,
float base,
const RopeInit &rope_init)
{
if (8 * tid >= rot_embed_dim) {
return;
}
float2 coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base, rope_init);
float2 coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base, rope_init);
float2 coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base, rope_init);
float2 coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base, rope_init);
q.x = rotary_embedding_transform(q.x, coef0);
q.y = rotary_embedding_transform(q.y, coef1);
q.z = rotary_embedding_transform(q.z, coef2);
q.w = rotary_embedding_transform(q.w, coef3);
}
#endif // ENABLE_BF16
// Standard (NeoX-style) rope: stage the vector into shared memory, reread it
// in the half-rotated pairing, rotate, write back, then reload contiguously.
// `offset` shifts which channels participate (used by the Yarn config).
template<typename RopeInit, typename scalar_t, typename vector_t>
__device__ __inline__ void normal_rope(vector_t& x,
scalar_t* smem,
const int tidx,
const int seqidx,
const int dim,
const float base,
const RopeInit& rope_init,
const int offset = 0)
{
const int vec_size = vector_size<scalar_t, vector_t>::size;
const int rope_idx = tidx * vec_size - offset;
const bool work = (rope_idx >= 0 && rope_idx < dim);
const int rope_tidx = rope_idx / vec_size;
if (work) {
reinterpret_cast<vector_t*>(smem)[rope_tidx] = x;  // stage contiguously
}
__syncthreads();
if (work) {
RotaryHalfRead(x, smem, rope_tidx, dim / 2);       // gather paired lanes
apply_rotary_embedding(x, rope_tidx, dim, seqidx, base, rope_init);
RotaryHalfWrite(x, smem, rope_tidx, dim / 2);      // scatter back
}
__syncthreads();
if (work) {
x = reinterpret_cast<vector_t*>(smem)[rope_tidx];  // reload contiguously
}
}
// GLM2-style rope: rotates interleaved pairs in-register, with no shared
// memory reordering; the caller passes dim / 2 so only the first half rotates.
template<typename RopeInit, typename scalar_t, typename vector_t>
__device__ __inline__ void glm2_rope(vector_t& x,
const int tidx,
const int seqidx,
const int dim,
const float base,
const RopeInit& rope_init)
{
const int vec_size = vector_size<scalar_t, vector_t>::size;
const bool work = (tidx * vec_size < dim);
if (work) {
apply_rotary_embedding(x, tidx, dim, seqidx, base, rope_init);
}
}
// Dynamic NTK-aware scaling: grows the base once seq_len exceeds the trained
// window, new_base = base * (scale * seq_len / max_pos - (scale - 1))^(dim / (dim - 2)).
__device__ __inline__ float get_dynamic_ntk_base(const int dim, const float base, const int seq_len, const float scale, const int max_pos) {
float new_base = base * pow(((scale * seq_len / max_pos) - (scale - 1)), (dim / (dim - 2.0f)));
return new_base;
}
// Qwen variant: the NTK scalar roughly doubles with each doubling of seq_len / max_pos.
__device__ __inline__ float get_qwen_dynamic_ntk_base(const int dim, const float base, const int seq_len, const int max_pos) {
float context_value = logf((float)seq_len / max_pos) / logf(2.0f) + 1.0f;
float ntk_scalar = pow(2.0f, ceil(context_value)) - 1;
ntk_scalar = max(ntk_scalar, 1.0f);
float new_base = base * pow((float)ntk_scalar, (float)dim / (dim - 2));
return new_base;
}
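// Worked example for the Qwen variant (our numbers): seq_len = 4 * max_pos
// gives context_value = log2(4) + 1 = 3, ntk_scalar = 2^3 - 1 = 7, so the
// base grows by a factor of 7^(dim / (dim - 2)).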
// Applies the configured rope style to one vector_t of channels. `smem` is
// scratch used by normal_rope; Glm2 rotates in-register and ignores it.
template<typename scalar_t, typename vector_t, RopeStyle ROPE_STYLE>
__device__ inline void apply_rope(RopeConfig rope_config,
vector_t& x,
scalar_t* smem,
int tidx,
int seqidx,
int seq_len)
{
auto base = rope_config.base;
auto dim = rope_config.dim;
switch (ROPE_STYLE) {
case RopeStyle::No:
break;
case RopeStyle::Base:
normal_rope(x, smem, tidx, seqidx, dim, base, LinearScaleRope{rope_config.scale});
break;
case RopeStyle::Glm2:
// only do rotary embedding for [..., d / 2]
glm2_rope<LinearScaleRope, scalar_t, vector_t>(x, tidx, seqidx, dim / 2, base, LinearScaleRope{rope_config.scale});
break;
case RopeStyle::DynamicNTK:
if (seq_len > rope_config.max_pos) {
base = get_dynamic_ntk_base(dim, base, seq_len, rope_config.scale, rope_config.max_pos);
}
normal_rope(x, smem, tidx, seqidx, dim, base, DefaultRope{});
break;
case RopeStyle::Yarn:
normal_rope(x, smem, tidx, seqidx, dim, base,
YarnRope{rope_config.dim, rope_config.base, rope_config.max_pos,
rope_config.factor1, rope_config.factor2, rope_config.scale,
rope_config.extrapolation_factor, rope_config.mscale},
rope_config.offset);
break;
case RopeStyle::QwenDynamicNTK:
if (seq_len > rope_config.max_pos) {
base = get_qwen_dynamic_ntk_base(dim, base, seq_len, rope_config.max_pos);
}
normal_rope(x, smem, tidx, seqidx, dim, base, DefaultRope{});
break;
case RopeStyle::Llama3:
normal_rope(x, smem, tidx, seqidx, dim, base,
Llama3Rope{rope_config.factor1, rope_config.factor2, rope_config.scale,
rope_config.max_pos});
break;
case RopeStyle::Mrope:
normal_rope(x, smem, tidx, seqidx, dim, base, LinearScaleRope{rope_config.scale});
break;
default:
break;
}
}
// Prefill path: rotates q and k at this token's position, optionally shifted
// by a cached prefix prompt and overridable by an explicit position id.
template<typename scalar_t, typename vector_t, RopeStyle ROPE_STYLE>
__device__ inline void context_rope(RopeConfig rope_config,
vector_t& q,
vector_t& k,
scalar_t* smem,
int tidx,
int seqidx,
int position_id,
int seq_len,
int input_len,
bool PREFIX_PROMPT,
int prefix_prompt_length,
int count_length)
{
if (PREFIX_PROMPT && count_length) {
input_len = input_len + prefix_prompt_length;
seqidx = seqidx + prefix_prompt_length;
}
if (position_id > 0) {
seqidx = position_id;
}
apply_rope<scalar_t, vector_t, ROPE_STYLE>(
rope_config, q, smem, tidx, seqidx, seq_len);
apply_rope<scalar_t, vector_t, ROPE_STYLE>(
rope_config, k, smem, tidx, seqidx, seq_len);
}
// Decode path: rotates q (and k when handle_kv is set) at position `tlength`,
// or at `position_id` when one is supplied.
template<typename scalar_t, typename vector_t, RopeStyle ROPE_STYLE>
__device__ inline void attention_rope(RopeConfig rope_config,
vector_t& q,
vector_t& k,
scalar_t* smem,
int tidx,
int tlength,
int timestep,
int seq_len,
int position_id,
int input_len,
#pragma nv_diagnostic push
#pragma nv_diag_suppress 550
[[maybe_unused]] int prefix_prompt_length,
#pragma nv_diagnostic pop
int count_prefix_length,
bool handle_kv)
{
if (count_prefix_length) {
prefix_prompt_length = 0;
}
if (position_id > 0) {
tlength = position_id;
}
if constexpr (ROPE_STYLE == RopeStyle::Glm2) {
tlength = tlength - prefix_prompt_length;
}
apply_rope<scalar_t, vector_t, ROPE_STYLE>(
rope_config, q, smem, tidx, tlength, seq_len);
if (handle_kv) {
apply_rope<scalar_t, vector_t, ROPE_STYLE>(
rope_config, k, smem, tidx, tlength, seq_len);
}
}
// In-place rope over a packed [token_num, head_num, head_size] buffer.
// Grid is (token, head); each thread owns two adjacent scalars (one vector_t),
// so threads with tidx * 2 >= head_size simply idle.
template<typename scalar_t, typename vector_t>
__global__ void launchApplyRopeKernel(scalar_t* input,
RopeConfig rope_config,
int head_num,
int head_size,
int seq_len,
const int* padding_offset,
const int* prefill_length) {
extern __shared__ __align__(sizeof(float2)) char smem[];
const int token_idx = blockIdx.x;
const int head_num_idx = blockIdx.y;
const int tidx = threadIdx.x;
const int head_size_idx = tidx * 2;
// undo padding removal to recover this token's (batch, position), then shift
// by any prefilled prefix length
const int token_padding_offset = padding_offset == nullptr ? 0 : padding_offset[token_idx];
const int tgt_token_idx = token_idx + token_padding_offset;
const int batch_idx = tgt_token_idx / seq_len;
const int token_prefill_length = prefill_length == nullptr ? 0 : prefill_length[batch_idx];
const int seq_idx = tgt_token_idx % seq_len + token_prefill_length;
const bool work = (head_num_idx < head_num && head_size_idx < head_size);
if (work) {
vector_t x;
const int offset = token_idx * head_num * head_size + head_num_idx * head_size + head_size_idx;
x = *reinterpret_cast<vector_t*>(&input[offset]);
FT_ROPE_SWITCH(rope_config.style, ROPE_STYLE, [&] {
apply_rope<scalar_t, vector_t, ROPE_STYLE>(
rope_config, x, reinterpret_cast<scalar_t*>(smem), tidx, seq_idx, seq_len);
});
*reinterpret_cast<vector_t*>(&input[offset]) = x;
}
}
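#if USING_CUDA
// A hypothetical host-side launch (grid/block choice, shared-memory size and
// the wrapper name are our assumptions; the project wires this up elsewhere):
// one block per (token, head), one two-scalar pair per thread.
template<typename scalar_t, typename vector_t>
inline void applyRopeExampleLaunch(scalar_t* input,
RopeConfig rope_config,
int token_num,
int head_num,
int head_size,
int seq_len,
const int* padding_offset,
const int* prefill_length,
cudaStream_t stream) {
dim3 grid(token_num, head_num);
dim3 block(head_size / 2);                         // tidx * 2 must cover head_size
size_t smem = rope_config.dim * sizeof(scalar_t);  // scratch consumed by normal_rope
launchApplyRopeKernel<scalar_t, vector_t><<<grid, block, smem, stream>>>(
input, rope_config, head_num, head_size, seq_len, padding_offset, prefill_length);
}
#endif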
} // namespace rtp_llm