in candle-transformers/src/models/mobilenetv4.rs [499:626]
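// Mobile Multi-Query Attention (MQA) block: every query head attends to a single
// shared key/value head of width `kv_dim`, and the key/value feature map can be
// spatially downsampled (`kv_stride`) before attention to reduce its cost.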
fn mqa_block(
in_channels: usize,
out_channels: usize,
heads: usize,
kernel: usize,
stride: usize,
kv_dim: usize,
kv_stride: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
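    // Depthwise conv config used to spatially downsample the key/value feature map
    // when kv_stride > 1.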
let down_conv2d_cfg = Conv2dConfig {
stride: kv_stride,
padding: kernel / 2,
groups: in_channels,
..Default::default()
};
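    // 1x1 projection config for the query/key/value/output projections; `stride`
    // here is the overall block stride.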
let proj_conv2d_cfg = Conv2dConfig {
stride,
..Default::default()
};
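    // A residual connection is only added when the block preserves both the channel
    // count and the spatial resolution.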
let skip_connection = (in_channels == out_channels) && (stride == 1);
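    // Layer-scale gamma is optional, so the Result is kept as-is and only applied
    // later if the weight is present.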
let gamma = vb.get(out_channels, "layer_scale.gamma");
let norm = batch_norm(out_channels, 1e-5, vb.pp("norm"))?;
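    // Scale queries by 1/sqrt(kv_dim), the per-head key/value width.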
let scale = (kv_dim as f64).powf(-0.5);
let vb = vb.pp("attn");
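    // Query projection: 1x1 conv producing `heads * kv_dim` channels, one kv_dim
    // slice per head.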
let query_proj = conv2d_no_bias(
out_channels,
kv_dim * heads,
1,
proj_conv2d_cfg,
vb.pp("query.proj"),
)?;
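    // Key path: optional depthwise downsample + norm (kept as Results, since these
    // weights are only present when the key/value map is downsampled), then a 1x1
    // projection to a single shared kv_dim head.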
let key_down_conv = conv2d_no_bias(
in_channels,
out_channels,
kernel,
down_conv2d_cfg,
vb.pp("key.down_conv"),
);
let key_norm = batch_norm(out_channels, 1e-5, vb.pp("key.norm"));
let key_proj = conv2d_no_bias(out_channels, kv_dim, 1, proj_conv2d_cfg, vb.pp("key.proj"))?;
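    // Value path mirrors the key path: optional downsample + norm, then a 1x1
    // projection to kv_dim.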
let value_down_conv = conv2d_no_bias(
in_channels,
out_channels,
kernel,
down_conv2d_cfg,
vb.pp("value.down_conv"),
);
let value_norm = batch_norm(out_channels, 1e-5, vb.pp("value.norm"));
let value_proj = conv2d_no_bias(
out_channels,
kv_dim,
1,
proj_conv2d_cfg,
vb.pp("value.proj"),
)?;
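    // Output projection: 1x1 conv mapping the concatenated heads back to out_channels.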
let output_proj = conv2d_no_bias(
kv_dim * heads,
out_channels,
1,
proj_conv2d_cfg,
vb.pp("output.proj"),
)?;
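    // The forward pass is packaged as a Func so the block can be applied like any
    // other Module.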
Ok(Func::new(move |xs| {
let (_, _, h, w) = xs.dims4()?;
let residual = xs.clone();
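        // Pre-normalization (eval mode) before computing q/k/v.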
let xs = xs.apply_t(&norm, false)?;
// Query
let q = xs.apply(&query_proj)?;
let q = reshape_query(&q, heads, kv_dim)?;
let q = (q * scale)?;
// Keys
let mut k = xs.clone();
if let (Ok(kd), Ok(n)) = (&key_down_conv, &key_norm) {
k = k.apply(kd)?.apply_t(n, false)?;
}
let k = k.apply(&key_proj)?;
let k = reshape_kv(&k)?;
// Value
let mut v = xs.clone();
if let (Ok(vd), Ok(n)) = (&value_down_conv, &value_norm) {
v = v.apply(vd)?;
v = v.apply_t(n, false)?;
}
let v = v.apply(&value_proj)?;
let v = reshape_kv(&v)?;
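        // Scaled dot-product attention; the single key/value head is broadcast
        // across all query heads.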
let attn = q.broadcast_matmul(&(k.transpose(D::Minus2, D::Minus1)?))?;
let attn = softmax(&attn, D::Minus1)?;
let o = attn.broadcast_matmul(&v)?;
let o = reshape_output(&o, heads, h, w)?;
let mut xs = o.apply(&output_proj)?;
// Layer scale
if let Ok(g) = &gamma {
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
        }
if skip_connection {
xs = (xs + residual)?;
}
Ok(xs)
}))
}
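
// Illustrative smoke test for the block above; not part of the upstream file. The
// channel/head/kv settings are arbitrary assumptions, and it relies on a
// VarMap-backed VarBuilder creating fresh weights for every tensor the block asks for.
#[cfg(test)]
mod mqa_block_sketch {
    use super::*;
    use candle::{DType, Device, Tensor};
    use candle_nn::{VarBuilder, VarMap};

    #[test]
    fn mqa_block_preserves_shape() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        // 48 channels, 4 heads of width 12, 3x3 depthwise key/value downsample with stride 2.
        let blk = mqa_block(48, 48, 4, 3, 1, 12, 2, vb)?;
        let xs = Tensor::randn(0f32, 1f32, (1, 48, 16, 16), &dev)?;
        let ys = xs.apply(&blk)?;
        // With in == out channels and stride 1 the block is shape-preserving.
        assert_eq!(ys.dims(), &[1, 48, 16, 16]);
        Ok(())
    }
}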