in crates/ratchet-core/src/ops/matmul/gemm.rs [462:612]
/// Renders the scalar tiled GEMM kernel for this matmul into `kernel_builder`.
///
/// The generated WGSL walks the shared (`dim_inner`) dimension in steps of `TILE_DIM`.
/// On each step every invocation copies part of a `TILE_DIM x TILE_DIM` tile of A and of B
/// into the workgroup arrays `mm_Asub` / `mm_Bsub`, synchronises with `workgroupBarrier()`,
/// and adds its contribution to a `ROW_PER_THREAD x ROW_PER_THREAD` block of `f32`
/// accumulators. Afterwards each accumulator is cast to `P`'s dtype, offset by `bias`
/// when one is present, and written with `mm_write` (coordinates swapped when
/// `self.trans_dst` is set).
///
/// How A is loaded depends on `self.lhs.dt()`: `F32`/`F16` read one element per
/// `mm_readA` call, while `Q8_0F`/`Q8_0H` read four values per call (indexed `val[0..4]`)
/// and scale them by the value returned from `getAbsMax`.
///
/// The tile loads index `mm_Asub` / `mm_Bsub` columns up to `TILE_DIM - 1`, while those
/// arrays hold `TILE_DIM / P::W` columns, so this kernel is only valid for `P::W == 1`.
///
/// # Panics
/// Panics with "Unsupported dtype" if `self.lhs` is not `F32`, `F16`, `Q8_0F` or `Q8_0H`.
///
/// # Errors
/// Returns any error produced by `kernel_builder.build()`.
fn render_scalar<P: WgslPrimitive>(
    &self,
    mut kernel_builder: WgslKernelBuilder,
) -> Result<KernelSource, OperationError> {
    const ROW_PER_THREAD: usize = 4;
    const TILE_DIM: usize = 32;
    let accessor = P::render_type();
    let dt = P::T::DT;
    // Vector width of `P`, and how many `P`-typed columns one shared-tile row holds.
    let w = P::W;
    let t_w = TILE_DIM / w;

    kernel_builder.write_global(wgsl! {
        var<workgroup> mm_Asub: array<array<'accessor, 't_w>, 'TILE_DIM>;
        var<workgroup> mm_Bsub: array<array<'accessor, 't_w>, 'TILE_DIM>;
    });

    // `globalRowStart` counts rows: the A tile spans `TILE_DIM` rows (the first dimension
    // of `mm_Asub`), so it is scaled by TILE_DIM rather than by the per-row column count
    // `t_w` (the two only coincide when `w == 1`).
    kernel_builder.write_main(wgsl! {
        let batch = i32(global_invocation_id.z);
        let batchA = batch % metadata.lhs_shape[0];
        let batchB = batch % metadata.rhs_shape[0];
        let tileRow = i32(local_invocation_id.y) * 'ROW_PER_THREAD;
        let tileCol = i32(local_invocation_id.x) * 'ROW_PER_THREAD;
        let globalRowStart = i32(workgroup_id.y) * 'TILE_DIM;
        let globalRow = i32(global_invocation_id.y) * 'ROW_PER_THREAD;
        let globalCol = i32(global_invocation_id.x) * 'ROW_PER_THREAD;
        let numTiles = (metadata.dim_inner - 1) / 'TILE_DIM + 1;
        var kStart = 0;
        //ALWAYS ACCUM IN FP32
        var acc: array<array<f32, 'ROW_PER_THREAD>, 'ROW_PER_THREAD>;
        let tileRowA = i32(local_invocation_id.y) * 'ROW_PER_THREAD;
        let tileColA = i32(local_invocation_id.x) * 'ROW_PER_THREAD;
        let tileRowB = i32(local_invocation_id.y) * 'ROW_PER_THREAD;
        // Loop over shared dimension.
    });

    // Body of the per-`innerRow` loop that stages this invocation's slice of the A tile.
    let a_inner = match self.lhs.dt() {
        DType::F32 | DType::F16 => {
            wgsl! {
                for (var innerCol = 0; innerCol < 'ROW_PER_THREAD; innerCol++) {
                    let inputRow = tileRowA + innerRow;
                    let inputCol = tileColA + innerCol;
                    mm_Asub[inputRow][inputCol] = mm_readA(batchA,
                                                           globalRowStart + inputRow,
                                                           kStart + inputCol);
                }
            }
        }
        DType::Q8_0F(_) | DType::Q8_0H(_) => {
            // One `mm_readA` call yields four values, unpacked into four adjacent columns.
            let mut inner = wgsl! {
                let curRow = globalRow + innerRow;
                let curCol = kStart + i32(local_invocation_id.x) * 4;
                let absmax = getAbsMax(batchA, curRow, curCol);
                let val = mm_readA(batchA, curRow, curCol) * absmax;
            };
            for i in 0..4 {
                inner.push_str(
                    &wgsl! { mm_Asub[tileRowA + innerRow][tileColA + 'i] = val['i]; },
                );
            }
            inner
        }
        _ => panic!("Unsupported dtype"),
    };
    let load_a = wgsl! {
        for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
            'a_inner
        }
    };
    let load_b = wgsl! {
        // Load one tile of B into local memory.
        for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
            for (var innerCol = 0; innerCol < 'ROW_PER_THREAD; innerCol++) {
                let inputRow = tileRowB + innerRow;
                let inputCol = tileCol + innerCol;
                mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol + innerCol);
            }
        }
    };
    // Operands are widened to f32 *before* multiplying. Multiplying in the input precision
    // first (e.g. f16) would round the product and can overflow to inf / flush to zero
    // before it ever reaches the f32 accumulator.
    let compute_acc = wgsl! {
        // Compute acc values for a single thread.
        for (var k = 0; k < 't_w; k++) {
            let bidx = k * 'w;
            let BCached0 = f32(mm_Bsub[bidx][tileCol + 0]);
            let BCached1 = f32(mm_Bsub[bidx][tileCol + 1]);
            let BCached2 = f32(mm_Bsub[bidx][tileCol + 2]);
            let BCached3 = f32(mm_Bsub[bidx][tileCol + 3]);
            for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
                let ACached = f32(mm_Asub[tileRow + innerRow][k]);
                acc[innerRow][0] += ACached * BCached0;
                acc[innerRow][1] += ACached * BCached1;
                acc[innerRow][2] += ACached * BCached2;
                acc[innerRow][3] += ACached * BCached3;
            }
        }
    };

    // The first barrier makes the staged tiles visible before they are read; the second
    // keeps the next iteration's loads from overwriting tiles still being read.
    kernel_builder.write_main(wgsl! {
        for (var t = 0; t < numTiles; t++) {
            'load_a
            'load_b
            kStart = kStart + 'TILE_DIM;
            workgroupBarrier();
            'compute_acc
            workgroupBarrier();
        }
        var val: 'accessor;
    });

    // Epilogue: cast each accumulator, add the optional bias, and store it.
    for row in 0..ROW_PER_THREAD {
        for col in 0..ROW_PER_THREAD {
            // The bias index follows whichever coordinate is passed last to `mm_write`
            // below (`globalRow + row` when transposed, `globalCol + col` otherwise).
            let bias_val = match (self.bias.is_some(), self.trans_dst) {
                (true, true) => wgsl! { bias[globalRow + 'row] },
                (true, false) => wgsl! { bias[globalCol + 'col] },
                (false, _) => wgsl! { 0. },
            };
            let writer = if self.trans_dst {
                wgsl! { mm_write(batch, globalCol + 'col, globalRow + 'row, val); }
            } else {
                wgsl! { mm_write(batch, globalRow + 'row, globalCol + 'col, val); }
            };
            kernel_builder.write_main(wgsl! {
                val = 'dt(acc['row]['col]) + 'bias_val;
                'writer
            });
        }
    }
    Ok(kernel_builder.build()?)
}