maga_transformer/cpp/kernels/_cast_to_int8.h (54 lines of code) (raw):
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int8_t cast_to_int8(float val) {
// Round a float to the nearest integer (ties to even on CUDA via .rni) and
// saturate it to the int8 range [-128, 127].
// Ported from vLLM:
// https://github.com/vllm-project/vllm/blob/c5832d2ae9431a1672d547c232ec46b1a9051ff0/csrc/quantization/compressed_tensors/int8_quant_kernels.cu#L8-L25
#ifdef USING_ROCM
    constexpr float i8_min = static_cast<float>(std::numeric_limits<int8_t>::min());
    constexpr float i8_max = static_cast<float>(std::numeric_limits<int8_t>::max());
    // Converting NaN to an integer type is undefined behaviour in C++.
    // Return 0 instead, intended to mirror the PTX cvt used on the CUDA path.
    if (std::isnan(val)) {
        return 0;
    }
    // nearbyint uses the current rounding mode; presumably round-to-nearest-even
    // on HIP (as the vLLM source notes), matching .rni on CUDA.
    float dst = std::nearbyint(val);
    // Saturate with explicit comparisons instead of std::clamp.
    // With some libstdc++ versions (e.g. GCC 14), std::clamp calls a host-only
    // assertion helper, which hip-clang cannot compile for the device.
    dst = (dst < i8_min) ? i8_min : ((dst > i8_max) ? i8_max : dst);
    return static_cast<int8_t>(dst);
#else
    // cvt.rni.sat.s8.f32: round to nearest even (.rni), clamp to int8 (.sat).
    // The 8-bit result is written into a 16-bit register ("h" constraint).
    // Its low byte carries the int8 value.
    int16_t dst;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(dst) : "f"(val));
    return static_cast<int8_t>(dst);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int32_t cast_to_int8(float2 val) {
    // Quantize both lanes of a float2 to int8 and pack them into the low
    // 16 bits of an int32: val.x -> byte 0, val.y -> byte 1.
    // Bytes 2-3 are explicitly zero so the whole returned word is defined.
    // Writing only two bytes of a union and reading its int32 member would
    // leave those upper bytes indeterminate.
    const uint32_t b0 = static_cast<uint8_t>(cast_to_int8(val.x));
    const uint32_t b1 = static_cast<uint8_t>(cast_to_int8(val.y));
    return static_cast<int32_t>(b0 | (b1 << 8));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int32_t cast_to_int8(float4 val) {
    // Quantize the four lanes of a float4 to int8 and pack them into one
    // 32-bit word. val.x goes in the least-significant byte and val.w in the
    // most-significant byte.
    const float lanes[4] = {val.x, val.y, val.z, val.w};
    uint32_t packed = 0u;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        const uint32_t q = static_cast<uint8_t>(cast_to_int8(lanes[i]));
        packed |= q << (8 * i);
    }
    return static_cast<int32_t>(packed);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int64_t cast_to_int8(Float8_ val) {
    // Quantize eight values to int8 and pack them into one 64-bit word.
    // Order from the least-significant byte upward is:
    // val.x.x, val.x.y, val.y.x, val.y.y, val.z.x, val.z.y, val.w.x, val.w.y.
    // NOTE(review): assumes Float8_'s nested .x/.y components are float.
    const float lanes[8] = {val.x.x, val.x.y, val.y.x, val.y.y,
                            val.z.x, val.z.y, val.w.x, val.w.y};
    uint64_t packed = 0u;
#pragma unroll
    for (int i = 0; i < 8; ++i) {
        const uint64_t q = static_cast<uint8_t>(cast_to_int8(lanes[i]));
        packed |= q << (8 * i);
    }
    return static_cast<int64_t>(packed);
}