in sat/diffusion_video.py [0:0]
def disable_untrainable_params(self):
    """Freeze all parameters except those selected by the trainable-module patterns.

    Walks ``self.named_modules()`` and enables gradients for a module when its
    qualified name matches any regex in ``self.trainable_modules`` (or when the
    pattern is the literal ``"all"``), unless the name also matches a regex in
    ``self.not_trainable_modules`` — exclusions win.  Every other parameter is
    frozen.  If ``self.unfreeze_adaLN_gate`` is set, additionally trains only
    the gate rows of each adaLN modulation projection.

    Side effects:
        - mutates ``requires_grad`` on every parameter of the model;
        - may register gradient hooks on the adaLN modulation weights/biases;
        - logs the selected module names and the trainable-parameter count.
    """
    trainable_modules = self.trainable_modules
    not_trainable_modules = self.not_trainable_modules
    print_rank0(f"{trainable_modules=}")
    print_rank0(f"{not_trainable_modules=}")

    names = []
    for name, module in self.named_modules():
        # Trainable when some inclusion pattern matches ...
        check = any(tm == "all" or re.search(tm, name) for tm in trainable_modules)
        # ... unless an exclusion pattern also matches (exclusions take precedence).
        if check and any(re.search(tm, name) for tm in not_trainable_modules):
            check = False
        if check:
            names.append(name)
        # named_modules() yields parents before children, so a child that matches
        # later deliberately overrides the flag set through its parent's subtree.
        for p in module.parameters():
            p.requires_grad_(check)

    # Count each parameter exactly once.  The previous per-module accumulation
    # double-counted parameters whenever both a module and one of its ancestors
    # matched (module.parameters() recurses into children) and kept counting
    # params that a later child-level exclusion re-froze.
    trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

    # unfreeze adaLN_modulations gate_msa / text-gate rows only
    if self.unfreeze_adaLN_gate:
        print_rank0(
            f"unfreezing adaLN_gate. shape: {self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations[0][1].weight.shape}"
        )
        hidden_size = self.model.diffusion_model.hidden_size  # loop-invariant
        for m in self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations:
            weight, bias = m[1].weight, m[1].bias
            # BUG FIX: ``requires_grad`` belongs to a whole tensor.  Indexing a
            # Parameter creates a temporary view, so the original
            # ``weight[a:b, :].requires_grad_(True)`` never affected training.
            # Instead, unfreeze the full parameter and zero the gradient of all
            # non-gate rows via a hook so only the gate slices get updates.
            # NOTE(review): decoupled weight decay (e.g. AdamW) still touches the
            # masked rows; use separate parameter groups if that matters.
            w_mask = weight.new_zeros(weight.shape)
            w_mask[5 * hidden_size : 6 * hidden_size] = 1
            w_mask[11 * hidden_size : 12 * hidden_size] = 1
            b_mask = bias.new_zeros(bias.shape)
            b_mask[5 * hidden_size : 6 * hidden_size] = 1
            b_mask[11 * hidden_size : 12 * hidden_size] = 1
            weight.requires_grad_(True)
            bias.requires_grad_(True)
            weight.register_hook(lambda grad, mask=w_mask: grad * mask)
            bias.register_hook(lambda grad, mask=b_mask: grad * mask)

    print_rank0("\n".join(["Trainable layers:"] + names))
    print_rank0(f"{trainable_params} params have been unfrozen for training.")