fn mobileone_block()

in candle-transformers/src/models/fastvit.rs [191:293]


/// Builds an inference-time MobileOne block as a single reparameterized convolution.
///
/// The optional training-time branches found under `vb` are each fused (via `fuse_conv_bn`)
/// into a `kernel x kernel` weight and a bias, and summed into one [`Conv2d`]:
///
/// * `conv_kxk.0.{conv,bn}`: the `kernel x kernel` conv + batch-norm branch.
/// * `conv_scale.{conv,bn}`: the 1x1 conv + batch-norm branch, zero-padded to
///   `kernel x kernel` so its single tap sits at the kernel centre.
/// * `identity`: a batch norm on the skip path, expressed as an identity convolution.
///   It is only considered when `in_channels == out_channels` and `stride == 1`.
///
/// A branch whose weights fail to load is skipped. The returned function applies the fused
/// convolution, then the `se` squeeze-and-excitation module if it could be loaded, then an
/// erf-based GELU when `use_act` is set.
///
/// * `group_size`: input channels per group; `0` selects a single group (dense conv).
/// * `kernel`: spatial kernel size. The "same" padding of `kernel / 2` assumes it is odd.
///
/// # Errors
/// Returns an error if a tensor operation on successfully loaded weights fails
/// (for example a shape mismatch between branches).
fn mobileone_block(
    in_channels: usize,
    out_channels: usize,
    kernel: usize,
    stride: usize,
    group_size: usize,
    use_act: bool,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    // `group_size` counts channels per group; 0 means one group covering all channels.
    let groups = if group_size == 0 {
        1
    } else {
        in_channels / group_size
    };
    let in_per_group = in_channels / groups;

    let padding = kernel / 2;
    let conv2d_cfg = Conv2dConfig {
        stride,
        groups,
        padding,
        ..Default::default()
    };

    // Accumulators for the fused kernel and bias; every branch that loads is added in.
    let mut w = Tensor::zeros(
        (out_channels, in_per_group, kernel, kernel),
        DType::F32,
        vb.device(),
    )?;
    let dim = out_channels;

    let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;

    // kxk conv + bn branch.
    let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn"));
    let conv_kxk = conv2d_no_bias(
        in_channels,
        out_channels,
        kernel,
        conv2d_cfg,
        vb.pp("conv_kxk.0.conv"),
    );

    if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
        let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
        w = (w + wk)?;
        b = (b + bk)?;
    }

    // 1x1 conv + bn branch. Only the weight is used, so the config's padding is irrelevant.
    let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"));
    let conv_scale = conv2d_no_bias(
        in_channels,
        out_channels,
        1,
        conv2d_cfg,
        vb.pp("conv_scale.conv"),
    );

    if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
        let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
        // Zero-pad the 1x1 kernel by `kernel / 2` on every side so it becomes
        // `kernel x kernel` with its tap at the centre (pads by 1 for a 3x3 kernel).
        let ws = ws
            .pad_with_zeros(D::Minus1, padding, padding)?
            .pad_with_zeros(D::Minus2, padding, padding)?;

        w = (w + ws)?;
        b = (b + bs)?;
    }

    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));

    // Identity (skip) branch: only well defined when the block keeps both the channel count
    // and the spatial resolution. Otherwise the identity kernel below would not fit `w`.
    if in_channels == out_channels && stride == 1 {
        if let Ok(id_bn) = batch_norm(dim, 1e-5, vb.pp("identity")) {
            // Kernel mapping every channel onto itself: a single 1.0 at
            // [i, i % in_per_group, kernel / 2, kernel / 2] for each channel i.
            // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
            let taps = kernel * kernel;
            let centre = (kernel / 2) * kernel + kernel / 2;
            let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
            for i in 0..in_channels {
                weights[(i * in_per_group + i % in_per_group) * taps + centre] = 1.0;
            }

            let weights = Tensor::from_vec(weights, w.shape(), w.device())?;
            let (wi, bi) = fuse_conv_bn(&weights, id_bn)?;

            w = (w + wi)?;
            b = (b + bi)?;
        }
    }

    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);

    Ok(Func::new(move |xs| {
        let mut xs = xs.apply(&reparam_conv)?;
        if let Ok(f) = &se {
            xs = xs.apply(f)?;
        }
        if use_act {
            xs = xs.gelu_erf()?;
        }
        Ok(xs)
    }))
}