fn render_vectorized()

in crates/ratchet-core/src/ops/matmul/gemm.rs [614:746]


    /// Renders the WGSL source of a tiled GEMM kernel whose element type is the
    /// primitive `P` (a scalar or vector of width `P::W`).
    ///
    /// The generated shader stages `TILE_DIM`-row tiles of both operands in two
    /// `var<workgroup>` arrays (`mm_Asub`, `mm_Bsub`), walks the inner dimension
    /// one tile at a time, and lets every invocation accumulate
    /// `ROW_PER_THREAD` output rows in an array of f32-typed accumulators. After
    /// the tile loop each accumulator is converted back to `P`'s type, the bias
    /// term (or `0.0` when `self.bias` is `None`) is added, and the result is
    /// handed to `mm_write`.
    ///
    /// The helpers referenced by the shader text (`mm_readA`, `mm_readB`,
    /// `mm_write`, `getAbsMax`, `metadata`, `bias`) are not emitted here; they
    /// must already be provided to `kernel_builder` by the caller.
    ///
    /// # Panics
    ///
    /// * If `P::W` is not 1, 2 or 4 ("Unsupported W").
    /// * If `self.lhs.dt()` is not `F32`, `F16`, `Q8_0F` or `Q8_0H`
    ///   ("Unsupported dtype").
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `kernel_builder.build()`.
    fn render_vectorized<P: WgslPrimitive>(
        &self,
        mut kernel_builder: WgslKernelBuilder,
    ) -> Result<KernelSource, OperationError> {
        // Number of output rows each invocation accumulates and writes.
        const ROW_PER_THREAD: usize = 4;
        // Row count of each workgroup tile; also the step by which `kStart`
        // advances along the inner dimension.
        const TILE_DIM: usize = 32;

        // WGSL type name of the primitive and its width.
        let accessor = P::render_type();
        let W = P::W;

        // f32 type with the same width as `P`; used for the accumulators so the
        // running sums are kept in f32 whatever `P`'s element type is.
        let fp32_accessor = match W {
            1 => Scalar::<f32>::render_type(),
            2 => Vec2::<f32>::render_type(),
            4 => Vec4::<f32>::render_type(),
            _ => panic!("Unsupported W"),
        };

        // Tile width measured in `accessor` elements (each element packs `W`
        // values, so `T_W * W == TILE_DIM`).
        let T_W = TILE_DIM / W;
        // Workgroup tiles for A and B: `TILE_DIM` rows of `T_W` elements each.
        kernel_builder.write_global(wgsl! {
            var<workgroup> mm_Asub: array<array<'accessor, 'T_W>, 'TILE_DIM>;
            var<workgroup> mm_Bsub: array<array<'accessor, 'T_W>, 'TILE_DIM>;
        });

        // Kernel prologue: batch indices (taken modulo the leading dimension of
        // each operand's shape), tile-local and global coordinates, the tile
        // count `ceil(dim_inner / TILE_DIM)` and the accumulator array.
        // NOTE(review): the indexing implies a workgroup of
        // (T_W, TILE_DIM / ROW_PER_THREAD) invocations; the workgroup size is
        // not set in this function - confirm against the caller.
        kernel_builder.write_main(wgsl! {
            let batch = i32(global_invocation_id.z);
            let batchA = batch % metadata.lhs_shape[0];
            let batchB = batch % metadata.rhs_shape[0];

            let localRow = i32(local_invocation_id.y);
            let tileRow = localRow * 'ROW_PER_THREAD;
            let tileCol = i32(local_invocation_id.x);

            let globalRow = i32(global_invocation_id.y) * 'ROW_PER_THREAD;
            let globalCol = i32(global_invocation_id.x) * 'W;

            let numTiles = (metadata.dim_inner - 1) / 'TILE_DIM + 1;
            var kStart = 0;

            var acc: array<'fp32_accessor, 'ROW_PER_THREAD>;

            // Loop over shared dimension.
            let tileRowB = localRow * 'ROW_PER_THREAD;
        });

        // Statement that stores one element of A into the workgroup tile. For
        // the Q8_0 dtypes the value read by `mm_readA` is additionally
        // multiplied by `getAbsMax(...)` for the same position (presumably the
        // dequantisation scale - `getAbsMax` is defined elsewhere; confirm).
        let load_a_inner = match self.lhs.dt() {
            DType::F32 | DType::F16 => {
                wgsl! { mm_Asub[inputRow][inputCol] = mm_readA(batchA, globalRow + innerRow, kStart + inputCol * 'W); }
            }
            DType::Q8_0F(_) | DType::Q8_0H(_) => {
                wgsl! {
                    let curRow = globalRow + innerRow;
                    let curCol = kStart + inputCol * 'W;

                    let absmax = getAbsMax(batchA, curRow, curCol);
                    mm_Asub[inputRow][inputCol] = mm_readA(batchA, curRow, curCol) * absmax;
                }
            }
            _ => panic!("Unsupported dtype"),
        };

        // Each invocation fills `ROW_PER_THREAD` rows of the A tile at its own
        // column `tileCol`, using the dtype-specific statement built above.
        let load_a = wgsl! {
            // Load one tile of A into local memory.
            for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
                let inputRow = tileRow + innerRow;
                let inputCol = tileCol;

                'load_a_inner
            }
        };

        // Same pattern for B: the tile row index is offset by `kStart` along
        // the inner dimension and the column is the invocation's `globalCol`.
        let load_b = wgsl! {
            // Load one tile of B into local memory.
            for (var innerRow = 0; innerRow < 'ROW_PER_THREAD; innerRow++) {
                let inputRow = tileRowB + innerRow;
                let inputCol = tileCol;

                mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol);
            }
        };

        // Unroll over the `W` components of one A element at generation time:
        //  * `outer_body` gets one `let BCached{c} = mm_Bsub[bidx + c][tileCol];`
        //    per component (B rows are cached once per `k`, outside the row loop);
        //  * `inner_body` gets the matching accumulate statement, which builds an
        //    `accessor` value from component `c` of `ACached`, multiplies it by
        //    `BCached{c}` and adds the f32-converted product to `acc[i]`.
        // NOTE(review): for W == 1 this emits `ACached[0]` on what would be a
        // scalar `accessor` - confirm this function is only instantiated with
        // vector primitives, or that the scalar case is handled elsewhere.
        // NOTE(review): the `128` passed to `WgslFragment::new` is presumably an
        // initial capacity - confirm against `WgslFragment`.
        let mut outer_body = WgslFragment::new(128);
        let mut inner_body = WgslFragment::new(128);
        for c in 0..W {
            let bIdent = format!("BCached{}", c);
            inner_body.write(wgsl! {
                acc[i] += 'fp32_accessor('accessor(ACached['c]) * 'bIdent);
            });
            outer_body.write(wgsl! { let 'bIdent = mm_Bsub[bidx + 'c][tileCol]; });
        }

        // Per-tile compute: for each of the `T_W` packed columns `k` of the A
        // tile, cache the `W` corresponding B rows, then update all
        // `ROW_PER_THREAD` accumulators of this invocation.
        let compute_acc = wgsl! {
            // Compute acc values for a single thread.
            for (var k = 0; k < 'T_W; k++) {
              let bidx = k * 'W;
              'outer_body
              for (var i = 0; i < 'ROW_PER_THREAD; i++) {
                let ACached = mm_Asub[tileRow + i][k];
                'inner_body
              }
            }
        };

        // Main tile loop. `kStart` is advanced right after the loads, so it
        // already points at the next tile while the current one is consumed.
        // The first barrier separates the tile loads from the compute phase;
        // the second one separates the compute phase from the next iteration's
        // loads, which overwrite the same workgroup arrays.
        // `val` is declared once here and reused by the write-back statements
        // emitted below.
        kernel_builder.write_main(wgsl! {
            for (var t = 0; t < numTiles; t++) {
                'load_a
                'load_b

                kStart = kStart + 'TILE_DIM;
                workgroupBarrier();

                'compute_acc
                workgroupBarrier();
            }

            var val: 'accessor;
        });

        // Bias operand for the write-back. Both fragments carry their own
        // trailing `;`, which is what terminates the `val = ...` statement
        // emitted in the loop below.
        let bias_val = if self.bias.is_some() {
            wgsl! { bias[globalCol / 'W]; }
        } else {
            wgsl! { 0.0; }
        };

        // Write-back, unrolled at generation time: one convert + bias-add +
        // `mm_write` per accumulated row.
        for i in 0..ROW_PER_THREAD {
            kernel_builder.write_main(wgsl! {
                val = 'accessor(acc['i]) + 'bias_val
                mm_write(batch, globalRow + 'i, globalCol, val);
            });
        }

        // Finalise the builder; `?` forwards a build failure to the caller.
        let x = kernel_builder.build()?;
        Ok(x)
    }