// candle-transformers/src/models/mmdit/model.rs
// Implements the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206),
// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5 Medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium).
// This follows the implementation of the MMDiT model in the ComfyUI repository:
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
// with MMDiT-X support following the Stability-AI/sd3.5 repository:
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use super::blocks::{
ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
};
use super::embedding::{
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
#[derive(Debug, Clone)]
pub struct Config {
    pub patch_size: usize,               // side length of each square latent patch
    pub in_channels: usize,              // channels of the latent input
    pub out_channels: usize,             // channels of the predicted latent output
    pub depth: usize,                    // number of joint blocks (also the attention head count)
    pub head_size: usize,                // per-head dimension; model width = head_size * depth
    pub adm_in_channels: usize,          // dimension of the pooled vector conditioning `y`
    pub pos_embed_max_size: usize,       // side of the learned positional embedding grid
    pub context_embed_size: usize,       // dimension of the incoming text context features
    pub frequency_embedding_size: usize, // sinusoidal frequency embedding size for timesteps
}
impl Config {
pub fn sd3_medium() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 24,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 192,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
pub fn sd3_5_medium() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 24,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 384,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
pub fn sd3_5_large() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 38,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 192,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
}
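
// A minimal sketch, not part of the upstream file, making the width convention
// explicit: the hidden size is head_size * depth, i.e. 24 * 64 = 1536 for the
// medium configs and 38 * 64 = 2432 for sd3_5_large. The helper name is an
// illustrative assumption.
#[allow(dead_code)]
impl Config {
    fn hidden_size(&self) -> usize {
        self.head_size * self.depth
    }
}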
pub struct MMDiT {
core: MMDiTCore,
patch_embedder: PatchEmbedder,
pos_embedder: PositionEmbedder,
timestep_embedder: TimestepEmbedder,
vector_embedder: VectorEmbedder,
context_embedder: nn::Linear,
unpatchifier: Unpatchifier,
}
impl MMDiT {
pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
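        // In this model family the attention head count equals `depth`, so the model
        // width is head_size * depth; `cfg.depth` is passed to MMDiTCore::new below
        // both as the number of blocks and as the number of heads.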
let hidden_size = cfg.head_size * cfg.depth;
let core = MMDiTCore::new(
cfg.depth,
hidden_size,
cfg.depth,
cfg.patch_size,
cfg.out_channels,
use_flash_attn,
vb.clone(),
)?;
let patch_embedder = PatchEmbedder::new(
cfg.patch_size,
cfg.in_channels,
hidden_size,
vb.pp("x_embedder"),
)?;
let pos_embedder = PositionEmbedder::new(
hidden_size,
cfg.patch_size,
cfg.pos_embed_max_size,
vb.clone(),
)?;
let timestep_embedder = TimestepEmbedder::new(
hidden_size,
cfg.frequency_embedding_size,
vb.pp("t_embedder"),
)?;
let vector_embedder =
VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
let context_embedder = nn::linear(
cfg.context_embed_size,
hidden_size,
vb.pp("context_embedder"),
)?;
let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
Ok(Self {
core,
patch_embedder,
pos_embedder,
timestep_embedder,
vector_embedder,
context_embedder,
unpatchifier,
})
}
pub fn forward(
&self,
x: &Tensor,
t: &Tensor,
y: &Tensor,
context: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
        // Forward pass of the MMDiT.
        // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        // t: (N,) tensor of diffusion timesteps
        // y: (N, adm_in_channels) tensor of pooled vector conditioning; the upstream
        //    docstring calls these "class labels", but for SD3 this is the pooled
        //    CLIP text embedding
        // context: (N, L, context_embed_size) tensor of sequence conditioning
        // skip_layers: optional indices of joint blocks to skip (skip-layer guidance)
let h = x.dim(D::Minus2)?;
let w = x.dim(D::Minus1)?;
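        // The learned positional table covers a pos_embed_max_size x pos_embed_max_size
        // patch grid; crop out the region matching the current latent resolution.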
let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
let x = self
.patch_embedder
.forward(x)?
.broadcast_add(&cropped_pos_embed)?;
let c = self.timestep_embedder.forward(t)?;
let y = self.vector_embedder.forward(y)?;
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
let x = self.core.forward(&context, &x, &c, skip_layers)?;
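        // Reassemble the patch sequence into an image; unpatchify rounds the spatial
        // dims up to multiples of the patch size, so crop back to the input size.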
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
}
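
// Hedged usage sketch, not part of the upstream file: assumes `vb` wraps real
// SD3-medium weights and that the inputs follow the shapes documented in
// `forward` above.
#[allow(dead_code)]
fn example_forward(
    vb: nn::VarBuilder,
    x: &Tensor,
    t: &Tensor,
    y: &Tensor,
    context: &Tensor,
) -> Result<Tensor> {
    let cfg = Config::sd3_medium();
    let mmdit = MMDiT::new(&cfg, /* use_flash_attn */ false, vb)?;
    // `None` runs every joint block; Some(&[..]) skips the listed indices.
    mmdit.forward(x, t, y, context, None)
}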
pub struct MMDiTCore {
joint_blocks: Vec<Box<dyn JointBlock>>,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
final_layer: FinalLayer,
}
impl MMDiTCore {
pub fn new(
depth: usize,
hidden_size: usize,
num_heads: usize,
patch_size: usize,
out_channels: usize,
use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
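        // The first `depth - 1` blocks are full joint blocks; the last position is
        // filled below by a context-QKV-only block, as the context stream's output
        // is discarded after the final block.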
let mut joint_blocks = Vec::with_capacity(depth - 1);
for i in 0..depth - 1 {
let joint_block_vb_pp = format!("joint_blocks.{i}");
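            // MMDiT-X blocks add a second self-attention (`attn2`) on the x stream;
            // probe the checkpoint for its weights to pick the right variant.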
let joint_block: Box<dyn JointBlock> =
if vb.contains_tensor(&format!("{joint_block_vb_pp}.x_block.attn2.qkv.weight")) {
Box::new(MMDiTXJointBlock::new(
hidden_size,
num_heads,
use_flash_attn,
vb.pp(&joint_block_vb_pp),
)?)
} else {
Box::new(MMDiTJointBlock::new(
hidden_size,
num_heads,
use_flash_attn,
vb.pp(&joint_block_vb_pp),
)?)
};
joint_blocks.push(joint_block);
}
Ok(Self {
joint_blocks,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
hidden_size,
num_heads,
use_flash_attn,
vb.pp(format!("joint_blocks.{}", depth - 1)),
)?,
final_layer: FinalLayer::new(
hidden_size,
patch_size,
out_channels,
vb.pp("final_layer"),
)?,
})
}
pub fn forward(
&self,
context: &Tensor,
x: &Tensor,
c: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
if let Some(skip_layers) = &skip_layers {
if skip_layers.contains(&i) {
continue;
}
}
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
self.final_layer.forward(&x, c)
}
}
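
// Hedged sketch, not part of the upstream file, of how `skip_layers` enables
// skip-layer guidance (SLG) for SD3.5: run one pass through all blocks and one
// with a few joint blocks skipped, then push the full prediction away from the
// degraded one. The indices [7, 8, 9] and `slg_scale` are illustrative
// assumptions, not values defined by this module.
#[allow(dead_code)]
fn example_skip_layer_guidance(
    mmdit: &MMDiT,
    x: &Tensor,
    t: &Tensor,
    y: &Tensor,
    context: &Tensor,
    slg_scale: f64,
) -> Result<Tensor> {
    let full = mmdit.forward(x, t, y, context, None)?;
    let skipped = mmdit.forward(x, t, y, context, Some(&[7, 8, 9]))?;
    let delta = ((&full - &skipped)? * slg_scale)?;
    full + delta
}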