fn render_scalar()

in crates/ratchet-core/src/ops/matmul/gemm.rs [462:612]


    /// Builds the WGSL source for the scalar tiled matmul kernel.
    ///
    /// The generated shader is assembled in this order:
    ///
    /// 1. Two `var<workgroup>` tile buffers, `mm_Asub` and `mm_Bsub`, are written as globals.
    /// 2. A prologue derives batch, tile and global indices from the invocation ids and
    ///    declares the per-thread `f32` accumulator `acc`.
    /// 3. A loop over `numTiles` loads one tile of A and one tile of B into the workgroup
    ///    buffers, issues `workgroupBarrier()`, accumulates into `acc`, and issues a second
    ///    `workgroupBarrier()`.
    /// 4. One statement pair per accumulator entry converts it with `P::T::DT`, adds the bias
    ///    (`0.` when `self.bias` is `None`) and hands the value to `mm_write`.
    ///
    /// `kernel_builder` is taken by value and consumed by the final `build()` call.
    ///
    /// # Errors
    ///
    /// Returns whatever error `kernel_builder.build()` yields, converted by `?`.
    ///
    /// # Panics
    ///
    /// Panics with `"Unsupported dtype"` when `self.lhs.dt()` is none of `DType::F32`,
    /// `DType::F16`, `DType::Q8_0F(_)` or `DType::Q8_0H(_)`.
    fn render_scalar<P: WgslPrimitive>(
        &self,
        mut kernel_builder: WgslKernelBuilder,
    ) -> Result<KernelSource, OperationError> {
        // Each thread owns a ROW_PER_THREAD x ROW_PER_THREAD block of `acc`.
        const ROW_PER_THREAD: usize = 4;
        // Extent of the workgroup tiles and the step by which `kStart` advances per iteration.
        const TILE_DIM: usize = 32;

        // `accessor` is the WGSL element type used for the tile buffers and for `val`;
        // `dt` is spliced in as the conversion applied to each accumulator entry on output.
        let accessor = P::render_type();
        let dt = P::T::DT;
        // NOTE(review): presumably `P::W` is the number of scalars packed into one `accessor`
        // (1 for the scalar primitive, making `T_W == TILE_DIM`) - confirm against
        // `WgslPrimitive`.
        let W = P::W;
        let T_W = TILE_DIM / W;
        kernel_builder.write_global(wgsl! {
            var<workgroup> mm_Asub: array<array<'accessor, 'T_W>, 'TILE_DIM>;
            var<workgroup> mm_Bsub: array<array<'accessor, 'T_W>, 'TILE_DIM>;
        });

        // Prologue: index setup and accumulator declaration.
        // The batch index is wrapped with `%` against the leading dimension of each operand
        // shape before being used to read A and B.
        // NOTE(review): `tileCol` multiplies by the literal 4 while `tileColA` (the same
        // expression otherwise) uses 'ROW_PER_THREAD; the two agree only while
        // ROW_PER_THREAD == 4.
        // NOTE(review): `globalRowStart` is scaled by 'T_W rather than 'TILE_DIM; these are
        // equal only when W == 1 - confirm that is always the case for this kernel.
        kernel_builder.write_main(wgsl! {
            let batch = i32(global_invocation_id.z);
            let batchA = batch % metadata.lhs_shape[0];
            let batchB = batch % metadata.rhs_shape[0];

            let tileRow = i32(local_invocation_id.y) * 'ROW_PER_THREAD;
            let tileCol = i32(local_invocation_id.x) * 4;

            let globalRowStart = i32(workgroup_id.y) * 'T_W;
            let globalRow = i32(global_invocation_id.y) * 'ROW_PER_THREAD;
            let globalCol = i32(global_invocation_id.x) * 'ROW_PER_THREAD;

            let numTiles = (metadata.dim_inner - 1) / 'TILE_DIM + 1;
            var kStart = 0;

            //ALWAYS ACCUM IN FP32
            var acc: array<array<f32, 'ROW_PER_THREAD>, 'ROW_PER_THREAD>;

            let tileRowA = i32(local_invocation_id.y) * 'ROW_PER_THREAD;
            let tileColA = i32(local_invocation_id.x) * 'ROW_PER_THREAD;
            let tileRowB = i32(local_invocation_id.y) * 'ROW_PER_THREAD;
            // Loop over shared dimension.
        });

        // Body of the A-tile load for a single `innerRow`, chosen by the LHS dtype.
        // It is spliced into the `innerRow` loop of `load_a` below, which is why the
        // snippets can refer to `innerRow` without declaring it.
        let a_inner = match self.lhs.dt() {
            // Float LHS: copy ROW_PER_THREAD elements of the row one at a time via `mm_readA`.
            DType::F32 | DType::F16 => {
                wgsl! {
                    for (var innerCol = 0; innerCol < 'ROW_PER_THREAD; innerCol++) {
                        let inputRow = tileRowA + innerRow;
                        let inputCol = tileColA + innerCol;

                        mm_Asub[inputRow][inputCol] = mm_readA(batchA,
                            globalRowStart + inputRow,
                            kStart + inputCol);
                    }
                }
            }
            // Quantized LHS: a single `mm_readA` result is scaled by `getAbsMax(...)`, then
            // its components 0..4 are stored into four consecutive columns of `mm_Asub`.
            // NOTE(review): presumably `mm_readA` yields a 4-component vector for these
            // dtypes, hence the literal 4 here rather than ROW_PER_THREAD - confirm.
            DType::Q8_0F(_) | DType::Q8_0H(_) => {
                let mut inner = wgsl! {
                    let curRow = globalRow + innerRow;
                    let curCol = kStart + i32(local_invocation_id.x) * 4;

                    let absmax = getAbsMax(batchA, curRow, curCol);
                    let val = mm_readA(batchA, curRow, curCol) * absmax;
                };
                // Unrolled at generation time: one store statement per component of `val`.
                for i in 0..4 {
                    inner.push_str(
                        &wgsl! { mm_Asub[tileRowA + innerRow][tileColA + 'i] = val['i]; },
                    );
                }
                inner
            }
            // NOTE(review): this panics although the function returns `Result`; consider
            // returning an `OperationError` instead if a suitable variant exists.
            _ => panic!("Unsupported dtype"),
        };

        // Wrap the dtype-specific body in the per-thread row loop.
        let load_a = wgsl! {
            for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
                'a_inner
            }
        };

        // B is always loaded element by element; rows are offset by `kStart` (the inner
        // dimension) and columns by `globalCol`.
        let load_b = wgsl! {
            // Load one tile of B into local memory.
            for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
                for (var innerCol = 0; innerCol < 'ROW_PER_THREAD; innerCol++) {
                    let inputRow = tileRowB + innerRow;
                    let inputCol = tileCol + innerCol;

                    mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol + innerCol);
                }
            }
        };

        // For each `k`, four B values (columns `tileCol + 0..3`) are cached once and reused
        // for every row handled by the thread. The products are converted to `f32` before
        // being added to `acc`.
        // NOTE(review): the four `BCached*` / `acc[innerRow][0..3]` lines hard-code four
        // columns per thread, so they implicitly rely on ROW_PER_THREAD == 4.
        let compute_acc = wgsl! {
            // Compute acc values for a single thread.
            for (var k = 0; k < 'T_W; k++) {
              let bidx = k * 'W;
              let BCached0 = mm_Bsub[bidx][tileCol + 0];
              let BCached1 = mm_Bsub[bidx][tileCol + 1];
              let BCached2 = mm_Bsub[bidx][tileCol + 2];
              let BCached3 = mm_Bsub[bidx][tileCol + 3];
              for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
                let ACached = mm_Asub[tileRow + innerRow][k];
                acc[innerRow][0] += f32(ACached * BCached0);
                acc[innerRow][1] += f32(ACached * BCached1);
                acc[innerRow][2] += f32(ACached * BCached2);
                acc[innerRow][3] += f32(ACached * BCached3);
              }
            }
        };

        // Main tile loop. `kStart` is advanced before the first barrier; that is harmless
        // because `compute_acc` reads only the workgroup tiles and never `kStart`.
        // NOTE(review): the trailing barrier presumably stops the next iteration's loads from
        // overwriting tiles other threads are still reading - confirm.
        // `val` is declared once here and reassigned by every write-out statement below.
        kernel_builder.write_main(wgsl! {
            for (var t = 0; t < numTiles; t++) {

                'load_a

                'load_b

                kStart = kStart + 'TILE_DIM;
                workgroupBarrier();

                'compute_acc
                workgroupBarrier();
            }

            var val: 'accessor;
        });

        // Write-out, unrolled at generation time: one assignment plus one `mm_write` call is
        // emitted for every (row, col) entry of `acc`.
        for row in 0..ROW_PER_THREAD {
            for col in 0..ROW_PER_THREAD {
                // With `trans_dst` the bias is indexed by the row coordinate, otherwise by
                // the column coordinate; without a bias the literal `0.` is added.
                let bias_val = if self.bias.is_some() {
                    if self.trans_dst {
                        wgsl! { bias[globalRow + 'row] }
                    } else {
                        wgsl! { bias[globalCol + 'col] }
                    }
                } else {
                    wgsl! { 0. }
                };

                // `trans_dst` swaps the row and column arguments passed to `mm_write`.
                let writer = if self.trans_dst {
                    wgsl! { mm_write(batch, globalCol + 'col, globalRow + 'row, val); }
                } else {
                    wgsl! { mm_write(batch, globalRow + 'row, globalCol + 'col, val); }
                };

                kernel_builder.write_main(wgsl! {
                    val = 'dt(acc['row]['col]) + 'bias_val;
                    'writer
                });
            }
        }

        // `?` converts the builder's error into `OperationError` before re-wrapping in `Ok`.
        Ok(kernel_builder.build()?)
    }