csrc/mps_kernels.metal (103 lines of code) (raw):
#include <metal_stdlib>
using namespace metal;
// Largest finite value representable in IEEE 754 half precision.
#define HLF_MAX 65504
// NOTE(review): presumably a threads-per-threadgroup size; verify against the
// host-side dispatch before relying on it.
#define TH 1024
// NOTE(review): presumably a per-thread element count — confirm.
#define NUM 4
// NOTE(review): presumably the number of elements in one quantization block — confirm.
#define NUM_BLOCK 4096
// Return the index (0..255) of the entry in `code` chosen to represent x.
//
// A 7-step binary search over `code` settles on `pivot`, and it tracks the
// nearest indices visited below and above it (`lower_pivot`, `upper_pivot`).
// x is then assigned either to `pivot` or to the neighbour on x's side of
// code[pivot]:
//   STOCHASTIC == false: whichever of the two is nearer by value; an x lying
//                        exactly on the midpoint goes to `pivot`.
//   STOCHASTIC == true:  the neighbour is chosen when
//                        rand >= |neighbour - x| / |neighbour - code[pivot]|,
//                        otherwise `pivot`. For rand uniform in [0, 1), the
//                        closer entry is chosen more often.
//
// rand - threshold for the stochastic choice; not read when STOCHASTIC is
//        false. NOTE(review): presumably uniform in [0, 1) — confirm at callers.
// code - lookup table; indices 0..255 are read. The search moves to higher
//        indices when x > code[pivot], so the table must be sorted ascending.
// x    - value to quantize.
template<bool STOCHASTIC>
static unsigned char quantize_scalar(
float rand,
device float* code,
float x)
{
// Start in the middle of the table. lower/upper hold the values of the
// nearest entries visited below/above x. The -1.0f/+1.0f starting values are
// placeholders that get replaced after the loop if they are never overwritten.
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
float lower = -1.0f;
float upper = 1.0f;
float val = code[pivot];
// Binary search with step i = 64, 32, 16, 8, 4, 2, 1 (seven steps); pivot
// finishes in [0, 254] and val == code[pivot] afterwards.
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
lower_pivot = pivot;
lower = val;
pivot+=i;
}
else
{
upper_pivot = pivot;
upper = val;
pivot-=i;
}
val = code[pivot];
}
// upper_pivot is still 255 only if the search never moved down, so `upper`
// still holds its placeholder; load the real last entry. Likewise,
// lower_pivot == 0 means the search never moved up, so load the first entry.
if(upper_pivot == 255)
upper = code[upper_pivot];
if(lower_pivot == 0)
lower = code[lower_pivot];
if(!STOCHASTIC)
{
// Deterministic: round to the nearer of code[pivot] and its neighbour.
if(x > val)
{
float midpoint = (upper+val)*0.5f;
if(x > midpoint)
{
return upper_pivot;
}
else
return pivot;
}
else
{
float midpoint = (lower+val)*0.5f;
if(x < midpoint)
return lower_pivot;
else
return pivot;
}
}
else
{
// Stochastic: compare rand with the normalised distance to the neighbour.
// NOTE(review): dist_full is 0 when the neighbour equals code[pivot]. Under
// IEEE rules the ratio is then inf/NaN and `pivot` is returned, but Metal
// compiles with fast-math by default, which may not preserve this — confirm.
if(x > val)
{
float dist_to_upper = fabs(upper-x);
float dist_full = upper-val;
if(rand >= dist_to_upper/dist_full) return upper_pivot;
else return pivot;
}
else
{
float dist_to_lower = fabs(lower-x);
float dist_full = val-lower;
if(rand >= dist_to_lower/dist_full) return lower_pivot;
else return pivot;
}
}
}
// Quantize every element of A to a one-byte index into the lookup table
// `code`, using round-to-nearest (non-stochastic) quantize_scalar.
//
// Bindings:
//   buffer(0) code - lookup table; quantize_scalar reads indices 0..255 and
//                    relies on the table being sorted ascending.
//   buffer(1) A    - n input floats (read only).
//   buffer(2) out  - n output bytes; out[i] is the index chosen for A[i].
//   buffer(3) n    - element count.
//
// Launch: any 1-D grid. Thread `id` handles elements id, id + grid_size,
// id + 2*grid_size, ... where grid_size is the total number of threads in the
// grid. A grid of n threads therefore does one element per thread, and a
// smaller grid still covers all n elements. Threads with id >= n do nothing.
//
// Every element is independent, so neither threadgroup memory nor barriers
// are needed. A threadgroup_barrier must be reached by every thread of the
// threadgroup, which a bounds-checked loop cannot guarantee.
kernel void quantize(device float* code [[buffer(0)]],
                     const device float* A [[buffer(1)]],
                     device uchar* out [[buffer(2)]],
                     constant uint& n [[buffer(3)]],
                     uint id [[thread_position_in_grid]],
                     uint grid_size [[threads_per_grid]]) {
    const uint count = n;
    for (uint i = id; i < count; i += grid_size) {
        out[i] = quantize_scalar<false>(0.0f, code, A[i]);
        // If i + grid_size >= count the loop is finished anyway; leaving here
        // keeps i + grid_size from wrapping past UINT_MAX for very large n.
        if (count - i <= grid_size) break;
    }
}