in candle-transformers/src/models/mobilenetv4.rs [648:742]
/// Builds the network body: every block of every stage in `cfg.stages`, chained
/// in order into a single [`Func`].
///
/// Each block's weights are loaded from `vb` under the prefix `"{stage}.{block}"`.
/// Both indices are the zero-based positions of the block within `cfg.stages`.
/// The first block receives `cfg.stem_dim` input channels. Every later block
/// receives the previous block's `out_channels`.
///
/// # Errors
/// Returns the first error produced while constructing any block. In the
/// returned `Func`, the first error produced by any block's forward pass is
/// propagated.
fn mobilenetv4_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    let mut in_channels = cfg.stem_dim;
    let mut blocks = Vec::new();
    // Walk the configured stages directly rather than a hard-coded `0..5`, so
    // the loop can never index past the end of `cfg.stages` nor skip stages.
    for (stage, stage_blocks) in cfg.stages.iter().enumerate() {
        for (block, block_type) in stage_blocks.iter().enumerate() {
            let block_vb = vb.pp(format!("{stage}.{block}"));
            // Each arm builds its block and reports the block's output channel
            // count, which becomes the input channel count of the next block.
            let (built, out_channels) = match *block_type {
                BlockType::Convolutional {
                    out_channels,
                    kernel,
                    stride,
                } => (
                    conv_block(cfg, in_channels, out_channels, kernel, stride, block_vb)?,
                    out_channels,
                ),
                BlockType::EdgeResidual {
                    out_channels,
                    kernel,
                    stride,
                    expand,
                } => (
                    edge_residual_block(
                        cfg,
                        in_channels,
                        out_channels,
                        kernel,
                        stride,
                        expand,
                        block_vb,
                    )?,
                    out_channels,
                ),
                BlockType::UniversalBottleneck {
                    out_channels,
                    start_kernel,
                    mid_kernel,
                    stride,
                    expand,
                } => (
                    universal_inverted_bottleneck_block(
                        cfg,
                        in_channels,
                        out_channels,
                        expand,
                        start_kernel,
                        mid_kernel,
                        stride,
                        block_vb,
                    )?,
                    out_channels,
                ),
                BlockType::Attention {
                    out_channels,
                    heads,
                    kernel,
                    stride,
                    kv_dim,
                    kv_stride,
                } => (
                    mqa_block(
                        in_channels,
                        out_channels,
                        heads,
                        kernel,
                        stride,
                        kv_dim,
                        kv_stride,
                        block_vb,
                    )?,
                    out_channels,
                ),
            };
            blocks.push(built);
            in_channels = out_channels;
        }
    }
    // Apply the blocks sequentially, short-circuiting on the first error.
    Ok(Func::new(move |xs| {
        blocks
            .iter()
            .try_fold(xs.clone(), |acc, block| acc.apply(block))
    }))
}