fn mobilenetv4_blocks()

in candle-transformers/src/models/mobilenetv4.rs [648:742]


/// Builds the MobileNetV4 body: every block of every configured stage, chained in order.
///
/// Weights for block `j` of stage `i` are looked up under the `"{i}.{j}"` sub-prefix of `vb`.
/// Channel counts are threaded through the network: the first block consumes
/// `cfg.stem_dim` channels, and every later block consumes the `out_channels` of the block
/// built just before it.
///
/// The returned `Func` applies the constructed blocks one after another to its input and
/// propagates the first error encountered.
fn mobilenetv4_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    // Channel count flowing into the next block to be built.
    let mut channels = cfg.stem_dim;
    let mut layers = Vec::new();

    for (stage_idx, stage) in cfg.stages.iter().enumerate() {
        for (block_idx, block_type) in stage.iter().enumerate() {
            let block_vb = vb.pp(format!("{stage_idx}.{block_idx}"));

            // Each arm builds its layer and reports the channel count it produces.
            let (layer, produced) = match *block_type {
                BlockType::Convolutional {
                    out_channels,
                    kernel,
                    stride,
                } => (
                    conv_block(cfg, channels, out_channels, kernel, stride, block_vb)?,
                    out_channels,
                ),
                BlockType::EdgeResidual {
                    out_channels,
                    kernel,
                    stride,
                    expand,
                } => (
                    edge_residual_block(
                        cfg,
                        channels,
                        out_channels,
                        kernel,
                        stride,
                        expand,
                        block_vb,
                    )?,
                    out_channels,
                ),
                BlockType::UniversalBottleneck {
                    out_channels,
                    start_kernel,
                    mid_kernel,
                    stride,
                    expand,
                } => (
                    universal_inverted_bottleneck_block(
                        cfg,
                        channels,
                        out_channels,
                        expand,
                        start_kernel,
                        mid_kernel,
                        stride,
                        block_vb,
                    )?,
                    out_channels,
                ),
                BlockType::Attention {
                    out_channels,
                    heads,
                    kernel,
                    stride,
                    kv_dim,
                    kv_stride,
                } => (
                    // The attention builder takes no `cfg` argument.
                    mqa_block(
                        channels,
                        out_channels,
                        heads,
                        kernel,
                        stride,
                        kv_dim,
                        kv_stride,
                        block_vb,
                    )?,
                    out_channels,
                ),
            };

            layers.push(layer);
            channels = produced;
        }
    }

    Ok(Func::new(move |xs| {
        layers
            .iter()
            .try_fold(xs.clone(), |acc, layer| acc.apply(layer))
    }))
}