in candle-transformers/src/models/fastvit.rs [191:293]
/// Builds a MobileOne block whose parallel branches are folded into a single convolution.
///
/// Up to three branches are looked up under `vb`. Each branch that loads successfully is fused
/// with its batch norm through `fuse_conv_bn` and summed into one weight/bias pair:
/// - `conv_kxk.0.{conv,bn}`: the `kernel` x `kernel` convolution,
/// - `conv_scale.{conv,bn}`: a 1x1 convolution, zero-padded up to `kernel` x `kernel`,
/// - `identity`: a batch norm on the input, expressed as an identity convolution kernel.
///
/// A branch whose tensors cannot be loaded is silently skipped (presumably because the
/// checkpoint does not contain it - confirm against the exporting model).
///
/// # Arguments
/// - `in_channels` / `out_channels`: channel counts of the fused convolution.
/// - `kernel`: spatial kernel size; the padding is `kernel / 2`, so an odd value is expected.
/// - `stride`: stride of the fused convolution.
/// - `group_size`: channels per group; `0` selects a single group, otherwise
///   `groups = in_channels / group_size`.
/// - `use_act`: when true, `gelu_erf` is applied to the output.
/// - `vb`: variable builder holding the branch weights.
///
/// # Returns
/// A `Func` applying the fused convolution, then the squeeze-and-excitation module under `se`
/// if it could be built, then the optional activation.
///
/// # Errors
/// Propagates tensor errors raised while fusing or summing branches (e.g. shape mismatches).
///
/// # Panics
/// When the `identity` branch is present and `in_channels > out_channels`, building the identity
/// kernel indexes out of bounds. NOTE(review): the identity branch is presumably only exported
/// when `in_channels == out_channels` and `stride == 1` - confirm against callers.
fn mobileone_block(
    in_channels: usize,
    out_channels: usize,
    kernel: usize,
    stride: usize,
    group_size: usize,
    use_act: bool,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let groups = if group_size == 0 {
        1
    } else {
        in_channels / group_size
    };
    let padding = kernel / 2;
    let conv2d_cfg = Conv2dConfig {
        stride,
        groups,
        padding,
        ..Default::default()
    };
    // Number of input channels seen by each output channel of the grouped convolution.
    let in_per_group = in_channels / groups;

    // Accumulators for the fused weight and bias; every available branch is added into them.
    let mut w = Tensor::zeros(
        (out_channels, in_per_group, kernel, kernel),
        DType::F32,
        vb.device(),
    )?;
    let mut b = Tensor::zeros(out_channels, DType::F32, vb.device())?;

    // kxk branch.
    let conv_kxk_bn = batch_norm(out_channels, 1e-5, vb.pp("conv_kxk.0.bn"));
    let conv_kxk = conv2d_no_bias(
        in_channels,
        out_channels,
        kernel,
        conv2d_cfg,
        vb.pp("conv_kxk.0.conv"),
    );
    if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
        let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
        w = (w + wk)?;
        b = (b + bk)?;
    };

    // 1x1 scale branch.
    let conv_scale_bn = batch_norm(out_channels, 1e-5, vb.pp("conv_scale.bn"));
    let conv_scale = conv2d_no_bias(
        in_channels,
        out_channels,
        1,
        conv2d_cfg,
        vb.pp("conv_scale.conv"),
    );
    if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
        let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
        // Centre the 1x1 kernel inside a kernel x kernel window. Padding by `kernel / 2` on each
        // side (rather than a fixed 1) keeps this correct for any odd kernel size.
        let ws = ws
            .pad_with_zeros(D::Minus1, padding, padding)?
            .pad_with_zeros(D::Minus2, padding, padding)?;
        w = (w + ws)?;
        b = (b + bs)?;
    };

    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));

    // Read and reparameterize the identity bn into wi and bi.
    let identity_bn = batch_norm(out_channels, 1e-5, vb.pp("identity"));
    if let Ok(id_bn) = identity_bn {
        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
        let kernel_area = kernel * kernel;
        // Flat offset of the central tap inside one kernel x kernel window.
        let center = padding * kernel + padding;
        // Identity kernel: output channel `i` reads input channel `i`, which sits at position
        // `i % in_per_group` within its group, at the spatial centre.
        // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
        for i in 0..in_channels {
            let idx = (i * in_per_group + i % in_per_group) * kernel_area + center;
            weights[idx] = 1.0;
        }
        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
        let (wi, bi) = fuse_conv_bn(weights, id_bn)?;
        w = (w + wi)?;
        b = (b + bi)?;
    };

    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
    Ok(Func::new(move |xs| {
        let mut xs = xs.apply(&reparam_conv)?;
        // The squeeze-and-excitation module is optional; it is applied only if it was built.
        if let Ok(f) = &se {
            xs = xs.apply(f)?;
        }
        if use_act {
            xs = xs.gelu_erf()?;
        };
        Ok(xs)
    }))
}