in sat/dit_video_concat.py
def _build_modules(self, module_configs):
    model_channels = self.hidden_size
    # time_embed_dim = model_channels * 4
    time_embed_dim = self.time_embed_dim
    self.time_embed = nn.Sequential(
        linear(model_channels, time_embed_dim),
        nn.SiLU(),
        linear(time_embed_dim, time_embed_dim),
    )
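
    # Optional label conditioning: an embedding table when num_classes is an int,
    # or small projection MLPs for the "continuous", "timestep", and "sequential" modes.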
    if self.num_classes is not None:
        if isinstance(self.num_classes, int):
            self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
        elif self.num_classes == "continuous":
            print("setting up linear c_adm embedding layer")
            self.label_emb = nn.Linear(1, time_embed_dim)
        elif self.num_classes == "timestep":
            self.label_emb = nn.Sequential(
                Timestep(model_channels),
                nn.Sequential(
                    linear(model_channels, time_embed_dim),
                    nn.SiLU(),
                    linear(time_embed_dim, time_embed_dim),
                ),
            )
        elif self.num_classes == "sequential":
            assert self.adm_in_channels is not None
            self.label_emb = nn.Sequential(
                nn.Sequential(
                    linear(self.adm_in_channels, time_embed_dim),
                    nn.SiLU(),
                    linear(time_embed_dim, time_embed_dim),
                )
            )
            if self.zero_init_y_embed:
                nn.init.constant_(self.label_emb[0][2].weight, 0)
                nn.init.constant_(self.label_emb[0][2].bias, 0)
        else:
            raise ValueError()
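
    # Positional-embedding mixin, sized for the patchified latent grid and the
    # temporally compressed frame count.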
    pos_embed_config = module_configs["pos_embed_config"]
    self.add_mixin(
        "pos_embed",
        instantiate_from_config(
            pos_embed_config,
            height=self.latent_height // self.patch_size,
            width=self.latent_width // self.patch_size,
            compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
            hidden_size=self.hidden_size,
        ),
        reinit=True,
    )
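
    # Patch-embedding mixin that turns in_channels latent patches into hidden_size tokens.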
    patch_embed_config = module_configs["patch_embed_config"]
    self.add_mixin(
        "patch_embed",
        instantiate_from_config(
            patch_embed_config,
            patch_size=self.patch_size,
            hidden_size=self.hidden_size,
            in_channels=self.in_channels,
        ),
        reinit=True,
    )
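
    # Per-layer timestep conditioning; only adaptive LayerNorm ("adaln") is implemented.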
    if self.input_time == "adaln":
        adaln_layer_config = module_configs["adaln_layer_config"]
        self.add_mixin(
            "adaln_layer",
            instantiate_from_config(
                adaln_layer_config,
                height=self.latent_height // self.patch_size,
                width=self.latent_width // self.patch_size,
                hidden_size=self.hidden_size,
                num_layers=self.num_layers,
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                hidden_size_head=self.hidden_size // self.num_attention_heads,
                time_embed_dim=self.time_embed_dim,
                elementwise_affine=self.elementwise_affine,
            ),
        )
    else:
        raise NotImplementedError
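
    # Final layer that projects hidden states back to out_channels per patch.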
    final_layer_config = module_configs["final_layer_config"]
    self.add_mixin(
        "final_layer",
        instantiate_from_config(
            final_layer_config,
            hidden_size=self.hidden_size,
            patch_size=self.patch_size,
            out_channels=self.out_channels,
            time_embed_dim=self.time_embed_dim,
            latent_width=self.latent_width,
            latent_height=self.latent_height,
            elementwise_affine=self.elementwise_affine,
        ),
        reinit=True,
    )
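
    # Optional LoRA adapters, configured with one set per transformer layer.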
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
return
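
A minimal sketch of the module_configs mapping this method expects. The four required keys and the optional "lora_config" are taken from the code above; the target paths and the empty params are illustrative assumptions (in the SAT setup these entries typically come from a YAML config), not the repository's actual defaults.

# Hypothetical config sketch; the target paths below are assumptions for illustration.
# instantiate_from_config reads each entry's "target" (a dotted class path) and merges
# its "params" with the keyword overrides passed inside _build_modules above.
module_configs = {
    "pos_embed_config": {
        "target": "dit_video_concat.Basic3DPositionEmbeddingMixin",  # assumed path
        "params": {},  # height/width/frames/hidden_size are supplied by _build_modules
    },
    "patch_embed_config": {
        "target": "dit_video_concat.ImagePatchEmbeddingMixin",  # assumed path
        "params": {},
    },
    "adaln_layer_config": {
        "target": "dit_video_concat.AdaLNMixin",  # assumed path
        "params": {},
    },
    "final_layer_config": {
        "target": "dit_video_concat.FinalLayerMixin",  # assumed path
        "params": {},
    },
    # "lora_config" may be added here to register the optional LoRA mixin.
}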