fn mobileone_block()

in candle-transformers/src/models/mobileone.rs [124:220]


/// Builds one MobileOne block in its inference-time (reparameterized) form.
///
/// The training-time branches stored in `vb` are each fused with their batch norm into an
/// equivalent convolution weight/bias pair. The pairs are summed into a single `Conv2d`. The
/// branches are:
/// - `k` kxk conv+BN branches;
/// - a 1x1 "scale" conv+BN branch, only when `kernel > 1`;
/// - an identity BN branch, only when `has_identity`.
///
/// The returned `Func` applies that fused convolution, then the squeeze-and-excitation module
/// (only if its weights could be loaded from `vb.pp("attn")`), then a ReLU.
///
/// * `has_identity` - whether to load and fold in the identity BN branch.
/// * `k` - number of kxk conv branches (training-time over-parameterization factor).
/// * `dim` - size of every batch norm and of the fused bias; presumably equal to
///   `out_channels` — TODO confirm.
/// * `stride`, `padding`, `groups` - configuration of the fused convolution.
/// * `kernel` - spatial size of the fused (square) convolution kernel.
/// * `in_channels`, `out_channels` - channel counts of the convolution.
/// * `vb` - variable builder holding the training-time weights.
///
/// # Errors
/// Propagates any tensor or weight-loading error. The one exception is the SE weights: a
/// failure to load them is not reported and simply disables the SE stage.
///
/// # Panics
/// Panics if `has_identity` is set while `in_channels != out_channels`.
fn mobileone_block(
    has_identity: bool,
    k: usize,
    dim: usize,
    stride: usize,
    padding: usize,
    groups: usize,
    kernel: usize,
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride,
        padding,
        groups,
        ..Default::default()
    };

    // Accumulators for the fused kernel and bias; the kernel layout is
    // (out_channels, in_channels / groups, kernel, kernel).
    let mut w = Tensor::zeros(
        (out_channels, in_channels / groups, kernel, kernel),
        DType::F32,
        vb.device(),
    )?;
    let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;

    // k is the training-time overparameterization factor, larger than 1 only in the s0 variant
    for i in 0..k {
        let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
        let conv_kxk = conv2d_no_bias(
            in_channels,
            out_channels,
            kernel,
            conv2d_cfg,
            vb.pp(format!("conv_kxk.{i}.conv")),
        )?;
        let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
        w = (w + wk)?;
        b = (b + bk)?;
    }

    if kernel > 1 {
        let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
        let conv_scale = conv2d_no_bias(
            in_channels,
            out_channels,
            1,
            conv2d_cfg,
            vb.pp("conv_scale.conv"),
        )?;

        let (ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
        // Zero-pad the 1x1 kernel on every spatial side so that it sits at the centre of a
        // kernel x kernel kernel and can be summed with `w` (padding is 1 for a 3x3 kernel).
        let pad = kernel / 2;
        let ws = ws
            .pad_with_zeros(D::Minus1, pad, pad)?
            .pad_with_zeros(D::Minus2, pad, pad)?;

        w = (w + ws)?;
        b = (b + bs)?;
    }

    // Use SE blocks if present (last layers of the s4 variant)
    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));

    // read and reparameterize the identity bn into wi and bi
    if has_identity {
        assert_eq!(
            in_channels, out_channels,
            "mobileone_block: the identity branch requires in_channels == out_channels"
        );
        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;

        // Build a kernel that acts as the identity map: for output channel i, a single 1.0 at
        // input channel i % (in_channels / groups) and the spatial centre of the kernel.
        // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
        let in_per_group = in_channels / groups;
        let kernel_area = kernel * kernel;
        let center = kernel / 2;
        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
        for i in 0..in_channels {
            // Row-major offset of element [i, i % in_per_group, center, center].
            let idx =
                (i * in_per_group + i % in_per_group) * kernel_area + center * kernel + center;
            weights[idx] = 1.0;
        }

        let weights = Tensor::from_vec(weights, w.shape(), w.device())?;
        let (wi, bi) = fuse_conv_bn(&weights, identity_bn)?;

        w = (w + wi)?;
        b = (b + bi)?;
    }

    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);

    Ok(Func::new(move |xs| {
        let mut xs = xs.apply(&reparam_conv)?;
        // The SE stage is optional: it only runs if its weights were found in `vb`.
        if let Ok(f) = &se {
            xs = xs.apply(f)?;
        }
        xs = xs.relu()?;
        Ok(xs)
    }))
}