// `render()` — excerpt from crates/ratchet-core/src/ops/matmul/quantized.rs (lines 138:278)

    /// Builds the WGSL kernel source for the quantized (K-quant, Q4_K-style)
    /// matmul dequantization path.
    ///
    /// The emitted shader contains:
    /// * helper functions that decode the 6-bit scale/min pairs packed into the
    ///   12-byte `scales` region of a K-quant block,
    /// * two `workgroup` tile arrays (`mm_Asub` / `mm_Bsub`),
    /// * a `main` body that dequantizes 64 4-bit weights (two sub-blocks) into
    ///   `result` via fused multiply-adds.
    ///
    /// # Arguments
    /// * `inplace` - forwarded to `register_bindings`; selects the in-place
    ///   binding layout.
    /// * `dst` - destination tensor; only used here to obtain the GPU device.
    /// * `workgroup_size` - workgroup dimensions baked into the kernel header.
    ///
    /// # Errors
    /// Propagates failures from device lookup, binding registration, and the
    /// final `kernel_builder.build()`.
    fn render<P: WgslPrimitive>(
        &self,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
    ) -> Result<KernelSource, OperationError> {
        let device = dst.device().try_gpu()?;
        let mut kernel_builder = WgslKernelBuilder::new(
            workgroup_size.clone(),
            rvec![
                BuiltIn::GlobalInvocationId,
                BuiltIn::LocalInvocationId,
                BuiltIn::WorkgroupId
            ],
            device.compute_features().clone(),
        );
        self.register_bindings::<P>(&mut kernel_builder, inplace)?;

        // Metadata struct is currently a placeholder; the kernel reads no
        // uniforms beyond the bindings registered above.
        kernel_builder.render_metadata(&QuantizedMeta { dummy: 0 });

        /* Extract the 16 x 6 bit values scales-mins pairs. The
         * encoding of those values is odd because of performance
         * reasons:
         *
         *  dddddddd dddddddd dddddddd dddddddd mmmmmmmm mmmmmmmm
         *  44000000|55111111|66222222|77333333|44000000|55111111
         *
         *  mmmmmmmm mmmmmmmm mmmmdddd mmmmdddd mmmmdddd mmmmdddd
         *  66222222|77333333|44444444|55555555|66666666|77777777
         *
         * In the above diagram you can see the 12 bytes and the
         * scales/mins 6 bits encodings. */
        // Pairs 0-3: scale and min each live as plain 6-bit fields in the
        // first two 32-bit words. Pairs 4-7: the low 4 bits sit in the third
        // word and the high 2 bits are the leftover top bits of the first two
        // words (mirrors llama.cpp's get_scale_min_k4 — TODO confirm against
        // the reference implementation when touching this).
        kernel_builder.write_global(wgsl! {
           fn extract_subblock_first_four(soffset: u32, pair_idx: u32) -> vec2<u32> {
                let s0 = scales[soffset]; //first 4 bytes
                let s1 = scales[soffset + 1u];//bytes 4-7
                let pair_bit_offset = (8u * pair_idx);
                return vec2<u32>((s0 >> pair_bit_offset) & 63u, (s1 >> pair_bit_offset) & 63u);
           }

           fn extract_subblock_latter_four(soffset: u32, pair_idx: u32) -> vec2<u32> {
                let s0 = scales[soffset]; //first 4 bytes
                let s1 = scales[soffset + 1u];//bytes 4-7
                let s2 = scales[soffset + 2u];//bytes 8-11

                //All of the lower bits are in the last 4 bytes (s2)
                //2bit values are distributed in 1-7 bytes
                //[01][011101] == 29
                let shift = 8u * (pair_idx - 4u);
                let dl = (s2 >> shift & 0xF); //mask 4 bits
                let dh = (s0 >> (6u + shift)) & 0x3; //mask 2 bits

                let ml = (s2 >> (shift + 4u) & 0xF);
                let mh = (s1 >> (6u + shift)) & 0x3;

                return vec2<u32>((dh << 4u) | dl, (mh << 4u) | ml);
           }


           fn get_subblock_scale_min(so: u32, pair_index: u32) -> vec2<u32> {
                return select(
                    extract_subblock_latter_four(so, pair_index),
                    extract_subblock_first_four(so, pair_index),
                    pair_index < 4u
                );
            }
        });

        // Tile dimensions interpolated into the shader below. Upper-case names
        // intentionally mirror the shader-side identifiers ('TW / 'TH).
        #[allow(non_snake_case)]
        let TW = 64;
        #[allow(non_snake_case)]
        let TH = 16;
        // NOTE(review): these workgroup tiles are not referenced by the `main`
        // body emitted here — presumably staging for a tiled matmul; confirm
        // before removing.
        kernel_builder.write_global(wgsl! {
            var<workgroup> mm_Asub: array<array<f16, 'TW>, 'TH>;
            var<workgroup> mm_Bsub: array<array<f16, 'TW>, 'TH>;
        });

        // Dequantize the first two 32-weight sub-blocks: low nibbles belong to
        // sub-block 0 (pair 0), high nibbles to sub-block 1 (pair 1). Each
        // `r[i]` holds 4 dequantized weights: w = q * (scale * d) - (min * dmin).
        kernel_builder.write_main(wgsl! {
            var scm = get_subblock_scale_min(0u, 0u);
            var delta = vec4<f16>(f16(scm.x) * d[0]);
            var min = vec4<f16>(f16(scm.y) * dmin[0]);

            //* We process two blocks per time, because each
            //* 32 bytes have 64 weights stored like this:
            //* First 32 weights of the first block are the higher 4
            //* bits of each byte. Second 32 weights of the second
            //* block are lower 4 bits of each byte.
            let packed0 = vec4<u32>(A[0], A[1], A[2], A[3]); // first 16 bytes
            let packed1 = vec4<u32>(A[4], A[5], A[6], A[7]); // second 16 bytes

            let b_mask: u32 = 0x0F0F0F0Fu;
            var b_value_lower: vec4<u32> = unpack4xU8(packed0.x & b_mask);
            var b_value_upper: vec4<u32> = unpack4xU8((packed0.x >> 4) & b_mask);

            var r: array<vec4<f16>, 16>;

            r[0] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed0.y & b_mask);
            r[1] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed0.z & b_mask);
            r[2] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed0.w & b_mask);
            r[3] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed1.x & b_mask);
            r[4] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed1.y & b_mask);
            r[5] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed1.z & b_mask);
            r[6] = fma(vec4<f16>(b_value_lower), delta, -min);
            b_value_lower = unpack4xU8(packed1.w & b_mask);
            r[7] = fma(vec4<f16>(b_value_lower), delta, -min);

            scm = get_subblock_scale_min(0u, 1u);
            delta = vec4<f16>(f16(scm.x) * d[0]);
            min = vec4<f16>(f16(scm.y) * dmin[0]);

            r[8] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed0.y >> 4) & b_mask);
            r[9] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed0.z >> 4) & b_mask);
            r[10] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed0.w >> 4) & b_mask);
            r[11] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed1.x >> 4) & b_mask);
            r[12] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed1.y >> 4) & b_mask);
            r[13] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed1.z >> 4) & b_mask);
            r[14] = fma(vec4<f16>(b_value_upper), delta, -min);
            b_value_upper = unpack4xU8((packed1.w >> 4) & b_mask);
            r[15] = fma(vec4<f16>(b_value_upper), delta, -min);

            for(var i =0; i < 16; i++) {
                for(var j=0; j < 4; j++) {
                    result[i * 4 + j] = r[i][j];
                }
            }
        });

        // Fixed: previously built into a local, dumped the entire generated
        // shader to stdout with `println!`, then returned it. Debug printing
        // does not belong in library code; return the built source directly.
        Ok(kernel_builder.build()?)
    }