in crates/ratchet-core/src/ops/matmul/gemm.rs [614:746]
/// Emits the WGSL source of a tiled matrix-multiply kernel whose tile elements
/// are `P` values of width `P::W`.
///
/// The generated kernel declares two workgroup arrays (`mm_Asub`, `mm_Bsub`),
/// walks the inner dimension in steps of `TILE_DIM`, accumulates
/// `ROW_PER_THREAD` rows per invocation in an f32 accumulator, and finally
/// writes each row (plus the bias when `self.bias` is set) through `mm_write`.
/// The shader helpers `mm_readA`, `mm_readB`, `mm_write` and `getAbsMax` are
/// only called here; they are presumably emitted elsewhere — confirm.
///
/// # Arguments
///
/// * `kernel_builder` - builder that receives the global declarations and the
///   main body; it is consumed and built at the end of this function.
///
/// # Errors
///
/// Propagates any error returned by `kernel_builder.build()`.
///
/// # Panics
///
/// Panics if `P::W` is not 1, 2 or 4, or if `self.lhs.dt()` is not one of
/// `F32`, `F16`, `Q8_0F`, `Q8_0H`.
fn render_vectorized<P: WgslPrimitive>(
&self,
mut kernel_builder: WgslKernelBuilder,
) -> Result<KernelSource, OperationError> {
// Number of output rows each invocation accumulates and writes.
const ROW_PER_THREAD: usize = 4;
// Edge length of a tile; `kStart` advances by this amount per loop iteration.
const TILE_DIM: usize = 32;
// Rendered type of `P`; interpolated below wherever a WGSL type name is expected.
let accessor = P::render_type();
let W = P::W;
// f32 type of the same width as `P`, used as the accumulator element type.
let fp32_accessor = match W {
1 => Scalar::<f32>::render_type(),
2 => Vec2::<f32>::render_type(),
4 => Vec4::<f32>::render_type(),
_ => panic!("Unsupported W"),
};
// Number of `W`-wide elements that make up one tile row.
let T_W = TILE_DIM / W;
// Workgroup storage for the current A and B tiles: TILE_DIM rows of T_W elements each.
kernel_builder.write_global(wgsl! {
var<workgroup> mm_Asub: array<array<'accessor, 'T_W>, 'TILE_DIM>;
var<workgroup> mm_Bsub: array<array<'accessor, 'T_W>, 'TILE_DIM>;
});
// Kernel prologue: batch/row/column indices derived from the invocation ids,
// the tile count, and the accumulator.
// - `batchA`/`batchB` wrap the batch index by the first entry of each operand's
//   shape (presumably to broadcast the batch dimension — confirm).
// - `numTiles` is `dim_inner / TILE_DIM` rounded up (for `dim_inner >= 1`).
// - `tileRowB` is computed with the same expression as `tileRow`.
kernel_builder.write_main(wgsl! {
let batch = i32(global_invocation_id.z);
let batchA = batch % metadata.lhs_shape[0];
let batchB = batch % metadata.rhs_shape[0];
let localRow = i32(local_invocation_id.y);
let tileRow = localRow * 'ROW_PER_THREAD;
let tileCol = i32(local_invocation_id.x);
let globalRow = i32(global_invocation_id.y) * 'ROW_PER_THREAD;
let globalCol = i32(global_invocation_id.x) * 'W;
let numTiles = (metadata.dim_inner - 1) / 'TILE_DIM + 1;
var kStart = 0;
var acc: array<'fp32_accessor, 'ROW_PER_THREAD>;
// Loop over shared dimension.
let tileRowB = localRow * 'ROW_PER_THREAD;
});
// Statement that stores one element of A into the tile; it depends on the
// dtype of the left-hand operand.
let load_a_inner = match self.lhs.dt() {
// Float inputs: the value returned by `mm_readA` is stored as is.
DType::F32 | DType::F16 => {
wgsl! { mm_Asub[inputRow][inputCol] = mm_readA(batchA, globalRow + innerRow, kStart + inputCol * 'W); }
}
// Q8_0 inputs: the value read is multiplied by the result of `getAbsMax`
// for the same position (presumably the dequantisation scale — confirm).
DType::Q8_0F(_) | DType::Q8_0H(_) => {
wgsl! {
let curRow = globalRow + innerRow;
let curCol = kStart + inputCol * 'W;
let absmax = getAbsMax(batchA, curRow, curCol);
mm_Asub[inputRow][inputCol] = mm_readA(batchA, curRow, curCol) * absmax;
}
}
_ => panic!("Unsupported dtype"),
};
// Each invocation fills ROW_PER_THREAD rows of the A tile at column `tileCol`.
let load_a = wgsl! {
// Load one tile of A into local memory.
for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
let inputRow = tileRow + innerRow;
let inputCol = tileCol;
'load_a_inner
}
};
// Same pattern for B: tile rows are offset by `kStart`, the column is `globalCol`.
let load_b = wgsl! {
// Load one tile of B into local memory.
for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
let inputRow = tileRowB + innerRow;
let inputCol = tileCol;
mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol);
}
};
// Fragments unrolled at generation time, one statement per component `c` of `P`.
// The argument 128 is presumably an initial capacity — confirm against `WgslFragment::new`.
let mut outer_body = WgslFragment::new(128);
let mut inner_body = WgslFragment::new(128);
for c in 0..W {
let bIdent = format!("BCached{}", c);
// Inner statement: multiply component `c` of the cached A element by the
// cached B row `c` and add the product, converted to the f32 type, to `acc[i]`.
// NOTE(review): for W == 1 this emits `ACached[0]` on what is presumably a
// scalar — confirm this path is only reached with W of 2 or 4.
inner_body.write(wgsl! {
acc[i] += 'fp32_accessor('accessor(ACached['c]) * 'bIdent);
});
// Outer statement: cache row `bidx + c` of the B tile under the name `BCached{c}`.
outer_body.write(wgsl! { let 'bIdent = mm_Bsub[bidx + 'c][tileCol]; });
}
// For every element index `k` of a tile row: cache W rows of B, then update
// all ROW_PER_THREAD accumulators from the matching A elements.
let compute_acc = wgsl! {
// Compute acc values for a single thread.
for (var k = 0; k < 'T_W; k++) {
let bidx = k * 'W;
'outer_body
for (var i = 0; i < 'ROW_PER_THREAD; i++) {
let ACached = mm_Asub[tileRow + i][k];
'inner_body
}
}
};
// Main tile loop: load both tiles, advance `kStart`, barrier, accumulate,
// barrier again before the next iteration overwrites the tiles. `val` is
// declared once after the loop and reused by the per-row writes below.
kernel_builder.write_main(wgsl! {
for (var t = 0; t < numTiles; t++) {
'load_a
'load_b
kStart = kStart + 'TILE_DIM;
workgroupBarrier();
'compute_acc
workgroupBarrier();
}
var val: 'accessor;
});
// Right-hand operand of the final addition: a bias element when a bias is
// present, otherwise the literal 0.0. Both fragments are written with their
// own trailing `;`, and the `val = …` line below has none, so the fragment
// presumably supplies the statement terminator — confirm.
// NOTE(review): `globalCol / W` presumably indexes a bias array of W-wide
// elements — confirm against the binding declaration.
let bias_val = if self.bias.is_some() {
wgsl! { bias[globalCol / 'W]; }
} else {
wgsl! { 0.0; }
};
// One store per accumulated row, unrolled at generation time: convert the
// accumulator to the output type, add the bias term and hand it to `mm_write`.
for i in 0..ROW_PER_THREAD {
kernel_builder.write_main(wgsl! {
val = 'accessor(acc['i]) + 'bias_val
mm_write(batch, globalRow + 'i, globalCol, val);
});
}
// Finalise the builder; any build error is propagated to the caller.
let x = kernel_builder.build()?;
Ok(x)
}