in candle-transformers/src/models/stable_diffusion/vae.rs [177:257]
/// Assembles the decoder layers, reading every sub-module from `vs`.
///
/// Layers are created in this order:
/// 1. `conv_in`: a 3x3 convolution (padding 1) from `in_channels` to the last
///    entry of `config.block_out_channels`.
/// 2. `mid_block`: a `UNetMidBlock2D` operating on that same channel count.
/// 3. `up_blocks`: one `UpDecoderBlock2D` per entry of
///    `config.block_out_channels`, visited in reverse order. Each block takes
///    the previous block's output width as its input width (the first block
///    uses its own width), and every block except the final one upsamples.
/// 4. `conv_norm_out`: a group norm over `config.block_out_channels[0]`.
/// 5. `conv_out`: a 3x3 convolution (padding 1) from
///    `config.block_out_channels[0]` to `out_channels`.
///
/// # Errors
///
/// Propagates any error returned while constructing one of the layers.
///
/// # Panics
///
/// Panics if `config.block_out_channels` is empty.
fn new(
    vs: nn::VarBuilder,
    in_channels: usize,
    out_channels: usize,
    config: DecoderConfig,
) -> Result<Self> {
    // Both convolutions of the decoder share this 3x3 / padding-1 setup.
    let conv3x3_cfg = || nn::Conv2dConfig {
        padding: 1,
        ..Default::default()
    };

    // The decoder starts from the last configured width and walks backwards.
    let entry_channels = *config.block_out_channels.last().unwrap();
    let conv_in = nn::conv2d(
        in_channels,
        entry_channels,
        3,
        conv3x3_cfg(),
        vs.pp("conv_in"),
    )?;

    let mid_block = UNetMidBlock2D::new(
        vs.pp("mid_block"),
        entry_channels,
        None,
        UNetMidBlock2DConfig {
            resnet_eps: 1e-6,
            output_scale_factor: 1.,
            attn_num_head_channels: None,
            resnet_groups: Some(config.norm_num_groups),
            ..Default::default()
        },
    )?;

    let mut up_blocks = vec![];
    let up_vs = vs.pp("up_blocks");
    let rev_channels: Vec<_> = config.block_out_channels.iter().rev().copied().collect();
    let block_count = rev_channels.len();
    for (idx, &block_out) in rev_channels.iter().enumerate() {
        // Block 0 has no predecessor, so it reuses its own width as input;
        // `saturating_sub` maps index 0 onto itself for exactly that purpose.
        let block_in = rev_channels[idx.saturating_sub(1)];
        let block_cfg = UpDecoderBlock2DConfig {
            num_layers: config.layers_per_block + 1,
            resnet_eps: 1e-6,
            resnet_groups: config.norm_num_groups,
            // Only the last block skips the upsampling step.
            add_upsample: idx + 1 != block_count,
            ..Default::default()
        };
        up_blocks.push(UpDecoderBlock2D::new(
            up_vs.pp(idx.to_string()),
            block_in,
            block_out,
            block_cfg,
        )?);
    }

    // The output head works on the first configured width.
    let head_channels = config.block_out_channels[0];
    let conv_norm_out = nn::group_norm(
        config.norm_num_groups,
        head_channels,
        1e-6,
        vs.pp("conv_norm_out"),
    )?;
    let conv_out = nn::conv2d(
        head_channels,
        out_channels,
        3,
        conv3x3_cfg(),
        vs.pp("conv_out"),
    )?;

    Ok(Self {
        conv_in,
        up_blocks,
        mid_block,
        conv_norm_out,
        conv_out,
        config,
    })
}