def disable_untrainable_params()

in sat/diffusion_video.py [0:0]


    def disable_untrainable_params(self):
        trainable_modules = self.trainable_modules
        not_trainable_modules = self.not_trainable_modules
        trainable_params = 0
        print_rank0(f"{trainable_modules=}")
        print_rank0(f"{not_trainable_modules=}")
        names = []
        for name, module in self.named_modules():
            check = False
            for tm in tuple(trainable_modules):
                if tm == "all" or re.search(tm, name):
                    check = True
                    break
            for tm in not_trainable_modules:
                if re.search(tm, name):
                    check = False
                    break
            if check:
                names.append(name)
                for m in module.parameters():
                    m.requires_grad_(True)
                    trainable_params += m.numel()
            else:
                for m in module.parameters():
                    m.requires_grad_(False)

        # unfreeze adaLN_modulations gate_mlp and text_gate_mlp
        if self.unfreeze_adaLN_gate:
            print_rank0(
                f"unfreezing adaLN_gate. shape: {self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations[0][1].weight.shape}"
            )
            for m in self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations:
                hidden_size = self.model.diffusion_model.hidden_size
                m[1].weight[5 * hidden_size : 6 * hidden_size, :].requires_grad_(True)
                m[1].weight[11 * hidden_size : 12 * hidden_size, :].requires_grad_(True)
                m[1].bias[5 * hidden_size : 6 * hidden_size].requires_grad_(True)
                m[1].bias[11 * hidden_size : 12 * hidden_size].requires_grad_(True)

        print_rank0("\n".join(["Trainable layers:"] + names))
        print_rank0(f"{trainable_params} params have been unfrozen for training.")