in crates/ratchet-core/src/ops/matmul/workgroup_gemv.rs [221:371]
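
/// Renders the WGSL source for the workgroup GEMV kernel: each invocation
/// along x owns one output row, the y axis strides over the inner dimension,
/// and partial sums are tree-reduced through workgroup memory before write-out.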
fn render<P: WgslPrimitive>(
    &self,
    inplace: bool,
    dst: &Tensor,
    workgroup_size: &WorkgroupSize,
) -> Result<KernelSource, OperationError> {
    let device = dst.device().try_gpu()?;
    let mut kernel_builder = WgslKernelBuilder::new(
        workgroup_size.clone(),
        rvec![
            BuiltIn::GlobalInvocationId,
            BuiltIn::LocalInvocationId,
            BuiltIn::WorkgroupId,
        ],
        device.compute_features().clone(),
    );
    self.register_bindings::<P>(&mut kernel_builder, inplace)?;
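
    // P::W is the vector width of the primitive (1, 2, or 4 lanes); sums are
    // always accumulated in f32, so pick the matching f32 accessor type.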
    let n = P::W;
    let fp32_accessor = match n {
        1 => "f32",
        2 => "vec2<f32>",
        4 => "vec4<f32>",
        _ => unimplemented!(),
    };
    let scalar = P::T::DT;
    let zero = P::T::zero().render();
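
    // Emit the metadata struct and, for quantized inputs, the unpack helper.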
    kernel_builder.render_metadata(&self.metadata(dst, &self.kernel_element(dst))?);
    kernel_builder.write_unpack(self.lhs.dt());
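
    // One vectorized partial sum per invocation, staged in workgroup memory.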
    let work_size = (workgroup_size.x * workgroup_size.y / (n as u32)).render();
    kernel_builder.write_global(wgsl! {
        var<workgroup> work: array<'fp32_accessor, 'work_size>;
    });
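
    // When the row count fits the tile exactly, readA can skip bounds checks;
    // otherwise rows past the end of A read back as zero.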
    let (TILE_X, _) = self.spec.heuristic.as_workgroup_size();
    let A_FIT = self.spec.lhs_shape()[1] % TILE_X == 0;

    let readA = match (A_FIT, self.lhs.dt()) {
        (true, DType::F32) | (true, DType::F16) => {
            wgsl! {
                fn readA(batch: i32, row: i32, col: i32) -> 'scalar {
                    return A[dot(metadata.lhs_strides, vec3<i32>(batch, row, col))];
                }
            }
        }
        (false, DType::F32) | (false, DType::F16) => {
            wgsl! {
                fn readA(batch: i32, row: i32, col: i32) -> 'scalar {
                    var val = 'zero;
                    if (row < metadata.lhs_shape.y) {
                        val = A[dot(metadata.lhs_strides, vec3<i32>(batch, row, col))];
                    }
                    return val;
                }
            }
        }
        (true, DType::Q8_0F(_)) | (true, DType::Q8_0H(_)) => {
            wgsl! {
                fn readA(batch: i32, row: i32, col: i32) -> vec4<'scalar> {
                    return unpack(A[dot(metadata.lhs_strides, vec3<i32>(batch, row, col))]);
                }
            }
        }
        _ => unimplemented!(),
    };
    kernel_builder.write_global(readA);
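
    // Main body: x indexes the output row, z the batch. Batches of A and X are
    // broadcast via modulo, so either operand may carry a batch dim of 1.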
    kernel_builder.write_main(wgsl! { let row = i32(global_invocation_id.x); });
    kernel_builder.write_main(wgsl! {
        let batch = i32(global_invocation_id.z);
        let batchA = batch % metadata.lhs_shape.x;
        let batchB = batch % metadata.rhs_shape.x;
    });
    kernel_builder.write_main(wgsl! {
        let aOffset = metadata.lhs_strides.x * batchA / 'n;
        let bOffset = metadata.rhs_strides.x * batchB / 'n;
        let outOffset = metadata.dst_strides.x * batch / 'n;
    });
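
    // Each invocation accumulates its strided slice of the dot product in f32.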
    kernel_builder.write_main(wgsl! { var sum = 'fp32_accessor(0.0); });
    kernel_builder
        .write_main(wgsl! { let aIndex = aOffset + row * metadata.lhs_strides.y / 'n; });
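
    // Q8_0 (assuming the usual GGML layout: 32-element blocks, 4 weights per
    // u32) steps the inner loop in vec4 units; one scale covers 8 packed
    // words, hence the k / 8 below.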
    let workgroup_size_y = workgroup_size.y;
    let main_loop = match self.lhs.dt() {
        DType::Q8_0F(_) | DType::Q8_0H(_) => {
            wgsl! {
                let sIndex = (aOffset / 4) + row * metadata.lhs_strides.y / 32;
                for (var k = i32(global_invocation_id.y); k < metadata.dim_inner / 4; k += 'workgroup_size_y / 4) {
                    sum += 'fp32_accessor(unpack(A[aIndex + k]) * scale[sIndex + (k / 8)] * X[k]);
                }
            }
        }
        _ => {
            wgsl! {
                for (var k = i32(global_invocation_id.y); k < metadata.dim_inner; k += 'workgroup_size_y) {
                    sum += 'fp32_accessor(readA(batchA, row, k) * X[bOffset + k]);
                }
            }
        }
    };
    kernel_builder.write_main(main_loop);
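
    // Stage each partial sum, then tree-reduce along y: halve the active
    // columns each step until column 0 holds the full row total.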
    let workgroup_size_x = workgroup_size.x.render();
    let workgroup_size_y = workgroup_size.y.render();
    kernel_builder.write_main(wgsl! {
        let rows = 'workgroup_size_x;
        let cols = 'workgroup_size_y / 'n;
        let ii = u32(local_invocation_id.x);
        let jj = u32(local_invocation_id.y);
        work[ii + rows * jj] = sum;
        workgroupBarrier();

        // Reduce sums in log2(cols) steps
        for (var s = u32(cols) / 2u; s > 0u; s >>= 1u) {
            if (jj < s) {
                work[ii + rows * jj] += work[ii + rows * (jj + s)];
            }
            workgroupBarrier();
        }
    });
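
    // Bias is optional; splice in a literal zero when absent.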
    let bias = if self.bias.is_some() {
        wgsl! { bias[row] }
    } else {
        wgsl! { 0. }
    };
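
    // Only lane jj == 0 writes out; vector accumulators are collapsed to a
    // scalar via a dot product with an all-ones vector.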
    let finalizer = match P::W {
        4 | 2 => {
            wgsl! { result[outOffset + row] = 'scalar(dot(work[ii], 'fp32_accessor(1.0)) + f32('bias)); }
        }
        1 => wgsl! { result[outOffset + row] = 'scalar(work[ii] + f32('bias)); },
        _ => unimplemented!(),
    };
    kernel_builder.write_main(wgsl! {
        if (jj == 0) {
            'finalizer
        }
    });

    Ok(kernel_builder.build()?)
}