fn mqa_block()

in candle-transformers/src/models/mobilenetv4.rs [499:626]

Builds the Mobile MQA (multi-query attention) block used by the MobileNetV4 hybrid models: every query head attends over a single shared key/value head of width kv_dim, and the keys/values can be spatially downsampled by a depthwise convolution when kv_stride > 1.

fn mqa_block(
    in_channels: usize,
    out_channels: usize,
    heads: usize,
    kernel: usize,
    stride: usize,
    kv_dim: usize,
    kv_stride: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
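    // Depthwise conv config (groups == in_channels) used to spatially
    // downsample the keys/values by `kv_stride`.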
    let down_conv2d_cfg = Conv2dConfig {
        stride: kv_stride,
        padding: kernel / 2,
        groups: in_channels,
        ..Default::default()
    };

    // 1x1 pointwise config shared by the query/key/value/output projections.
    let proj_conv2d_cfg = Conv2dConfig {
        stride,
        ..Default::default()
    };

    // The residual connection is only valid when the block preserves shape.
    let skip_connection = (in_channels == out_channels) && (stride == 1);
    // Layer scale is optional: keep the `Result` and only apply `gamma` in the
    // forward pass if the checkpoint actually contains it.
    let gamma = vb.get(out_channels, "layer_scale.gamma");
    let norm = batch_norm(out_channels, 1e-5, vb.pp("norm"))?;
    // Attention scaling factor 1 / sqrt(kv_dim).
    let scale = (kv_dim as f64).powf(-0.5);

    let vb = vb.pp("attn");

    // Queries get `heads * kv_dim` channels; keys/values get only `kv_dim` (MQA).
    let query_proj = conv2d_no_bias(
        out_channels,
        kv_dim * heads,
        1,
        proj_conv2d_cfg,
        vb.pp("query.proj"),
    )?;

    // Optional key downsampling; these weights exist only for blocks with
    // kv_stride > 1, so keep the `Result`s and check them in the forward pass.
    let key_down_conv = conv2d_no_bias(
        in_channels,
        out_channels,
        kernel,
        down_conv2d_cfg,
        vb.pp("key.down_conv"),
    );
    let key_norm = batch_norm(out_channels, 1e-5, vb.pp("key.norm"));

    // Project to the single shared key head of width `kv_dim`.
    let key_proj = conv2d_no_bias(out_channels, kv_dim, 1, proj_conv2d_cfg, vb.pp("key.proj"))?;

    // Same optional downsampling path for the values.
    let value_down_conv = conv2d_no_bias(
        in_channels,
        out_channels,
        kernel,
        down_conv2d_cfg,
        vb.pp("value.down_conv"),
    );
    let value_norm = batch_norm(out_channels, 1e-5, vb.pp("value.norm"));
    let value_proj = conv2d_no_bias(
        out_channels,
        kv_dim,
        1,
        proj_conv2d_cfg,
        vb.pp("value.proj"),
    )?;

    // Fuse the heads back to `out_channels` with a 1x1 projection.
    let output_proj = conv2d_no_bias(
        kv_dim * heads,
        out_channels,
        1,
        proj_conv2d_cfg,
        vb.pp("output.proj"),
    )?;

    Ok(Func::new(move |xs| {
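        // Keep the input spatial dims; the attention output is reshaped back to them.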
        let (_, _, h, w) = xs.dims4()?;

        let residual = xs.clone();

        // Pre-attention batch norm, in inference mode.
        let xs = xs.apply_t(&norm, false)?;

        // Queries: project, reshape into per-head form, then pre-scale
        // by 1 / sqrt(kv_dim).
        let q = xs.apply(&query_proj)?;
        let q = reshape_query(&q, heads, kv_dim)?;
        let q = (q * scale)?;

        // Keys: downsample if the optional depthwise conv + norm were loaded
        // (i.e. kv_stride > 1), then project to the single shared head.
        let mut k = xs.clone();
        if let (Ok(kd), Ok(n)) = (&key_down_conv, &key_norm) {
            k = k.apply(kd)?.apply_t(n, false)?;
        }
        let k = k.apply(&key_proj)?;
        let k = reshape_kv(&k)?;

        // Values: same optional downsampling path, then the shared-head projection.
        let mut v = xs.clone();
        if let (Ok(vd), Ok(n)) = (&value_down_conv, &value_norm) {
            v = v.apply(vd)?.apply_t(n, false)?;
        }
        let v = v.apply(&value_proj)?;
        let v = reshape_kv(&v)?;

        // Scaled dot-product attention; the single K/V head broadcasts across
        // all query heads, hence broadcast_matmul.
        let attn = q.broadcast_matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
        let attn = softmax(&attn, D::Minus1)?;
        let o = attn.broadcast_matmul(&v)?;

        // Back to NCHW so the 1x1 output projection can fuse the heads.
        let o = reshape_output(&o, heads, h, w)?;

        let mut xs = o.apply(&output_proj)?;

        // Layer scale, applied only if `layer_scale.gamma` exists in the checkpoint.
        if let Ok(g) = &gamma {
            xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
        }

        if skip_connection {
            xs = (xs + residual)?;
        }
        Ok(xs)
    }))
}
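
A minimal shape-check sketch, assuming zero-initialized weights from candle_nn::VarBuilder::zeros and illustrative hyperparameters (real values come from the MobileNetV4 model configurations). Since mqa_block is private, such a check would have to live in a hypothetical test module inside mobilenetv4.rs:

#[cfg(test)]
mod tests {
    use super::*;
    use candle::{DType, Device, Tensor};
    use candle_nn::VarBuilder;

    #[test]
    fn mqa_block_preserves_shape() -> Result<()> {
        let dev = Device::Cpu;
        // Zero weights stand in for a real checkpoint; this checks plumbing only.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        // Illustrative values: 96 channels in/out, 4 query heads of width 64,
        // 3x3 depthwise K/V downsampling with kv_stride = 2, stride = 1.
        let block = mqa_block(96, 96, 4, 3, 1, 64, 2, vb)?;
        let xs = Tensor::randn(0f32, 1f32, (1, 96, 24, 24), &dev)?;
        let ys = xs.apply(&block)?;
        // With stride == 1 and in_channels == out_channels the block is
        // shape-preserving, so the residual connection is active.
        assert_eq!(ys.dims(), &[1, 96, 24, 24]);
        Ok(())
    }
}

With zero weights the attention path contributes nothing and only the residual survives, so this exercises shapes and wiring rather than numerics; real weights would be loaded from a checkpoint (e.g. via VarBuilder::from_mmaped_safetensors) for a numeric check.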