fn new()

in candle-transformers/src/models/stable_diffusion/vae.rs [177:257]


    /// Builds the decoder from the weights found under `vs`.
    ///
    /// Sub-layers are loaded from these sub-paths of `vs`, in this order:
    /// - `conv_in`: kernel-3, padding-1 conv from `in_channels` to the *last*
    ///   entry of `config.block_out_channels`.
    /// - `mid_block`: a `UNetMidBlock2D` at that same channel count.
    /// - `up_blocks.{i}`: one `UpDecoderBlock2D` per entry of
    ///   `block_out_channels`, visited from the last entry back to the first.
    ///   Block `i` reads the output width of block `i - 1` (block 0 reads the
    ///   mid-block width). Every block except the final one has `add_upsample` set.
    /// - `conv_norm_out` / `conv_out`: group norm, then a kernel-3, padding-1 conv
    ///   from `block_out_channels[0]` to `out_channels`.
    ///
    /// # Panics
    /// Panics if `config.block_out_channels` is empty.
    ///
    /// # Errors
    /// Propagates the first error returned while constructing any sub-layer.
    fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        config: DecoderConfig,
    ) -> Result<Self> {
        let channels = &config.block_out_channels;
        let block_count = channels.len();
        let last_channels = *channels.last().unwrap();

        let conv_in = nn::conv2d(
            in_channels,
            last_channels,
            3,
            nn::Conv2dConfig {
                padding: 1,
                ..Default::default()
            },
            vs.pp("conv_in"),
        )?;

        let mid_block = UNetMidBlock2D::new(
            vs.pp("mid_block"),
            last_channels,
            None,
            UNetMidBlock2DConfig {
                resnet_eps: 1e-6,
                output_scale_factor: 1.,
                attn_num_head_channels: None,
                resnet_groups: Some(config.norm_num_groups),
                ..Default::default()
            },
        )?;

        // Up blocks walk the channel list back to front.
        let reversed: Vec<_> = channels.iter().rev().copied().collect();
        let up_root = vs.pp("up_blocks");
        let up_blocks = reversed
            .iter()
            .enumerate()
            .map(|(idx, &block_out)| {
                // Input width is the previous block's output; block 0 reuses its own
                // entry, which equals the mid-block width.
                let block_in = reversed[idx.saturating_sub(1)];
                let block_cfg = UpDecoderBlock2DConfig {
                    num_layers: config.layers_per_block + 1,
                    resnet_eps: 1e-6,
                    resnet_groups: config.norm_num_groups,
                    // Only the final block skips the upsampler.
                    add_upsample: idx + 1 != block_count,
                    ..Default::default()
                };
                UpDecoderBlock2D::new(up_root.pp(idx.to_string()), block_in, block_out, block_cfg)
            })
            // Short-circuits on the first failing block, in index order.
            .collect::<Result<Vec<_>>>()?;

        let first_channels = channels[0];
        let conv_norm_out = nn::group_norm(
            config.norm_num_groups,
            first_channels,
            1e-6,
            vs.pp("conv_norm_out"),
        )?;
        let conv_out = nn::conv2d(
            first_channels,
            out_channels,
            3,
            nn::Conv2dConfig {
                padding: 1,
                ..Default::default()
            },
            vs.pp("conv_out"),
        )?;

        Ok(Self {
            conv_in,
            up_blocks,
            mid_block,
            conv_norm_out,
            conv_out,
            config,
        })
    }