def configure()

in threestudio/models/guidance/stable_diffusion_bsd_guidance.py

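Sets up the Stable Diffusion pipelines used for BSD guidance: loads the base and LoRA pipelines, shares a single VAE between them, freezes all pretrained weights, creates trainable UNet copies, and wires up the DDPM training and DPM-Solver sampling schedulers.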

    def configure(self) -> None:
        threestudio.info(f"Loading Stable Diffusion ...")

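        # Run the frozen diffusion weights in fp16 when requested to save GPU memory.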
        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

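        # Tokenizer, safety checker and feature extractor are dropped: prompt
        # processing is handled outside this guidance module, and skipping the
        # safety modules avoids loading weights that are never used here.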
        pipe_kwargs = {
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only
        }

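        # Intentionally identical to pipe_kwargs; kept as a separate dict so the
        # two pipelines can be configured independently if needed.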
        pipe_lora_kwargs = {
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only
        }

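        # Bundle the three pipelines: pipe (base model), pipe_lora (LoRA
        # branch), and pipe_fix (frozen reference used by the *_fix modules).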
        @dataclass
        class SubModules:
            pipe: StableDiffusionPipeline
            pipe_lora: StableDiffusionPipeline
            pipe_fix: StableDiffusionPipeline

        pipe = StableDiffusionPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            **pipe_kwargs,
        ).to(self.device)
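        # Base and LoRA checkpoints are loaded as two separate pipelines, so
        # this is not the single-model path.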
        self.single_model = False
        pipe_lora = StableDiffusionPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path_lora,
            **pipe_lora_kwargs,
        ).to(self.device)
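        # Drop the LoRA pipeline's own VAE and reuse the base pipeline's VAE so
        # both pipelines share one copy; the duplicate is freed first.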
        del pipe_lora.vae
        cleanup()
        pipe_lora.vae = pipe.vae

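        # pipe_fix aliases the base pipeline; its UNet and VAE serve as the
        # frozen reference model.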
        pipe_fix = pipe

        self.submodules = SubModules(pipe=pipe, pipe_lora=pipe_lora, pipe_fix=pipe_fix)

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch2.0 uses memory efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory efficient attention is not enabled."
                )
            else:
                self.pipe.enable_xformers_memory_efficient_attention()
                self.pipe_lora.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()
            self.pipe_lora.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)
            self.pipe_lora.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)
            self.pipe_lora.unet.to(memory_format=torch.channels_last)

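        # Prompt embeddings are computed elsewhere, so the text encoders can be
        # deleted to free memory.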
        del self.pipe.text_encoder
        if not self.single_model:
            del self.pipe_lora.text_encoder
        cleanup()

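        # Freeze the VAE and the fixed (reference) modules; gradients flow only
        # through the trainable UNet copies created below.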
        for p in self.vae.parameters():
            p.requires_grad_(False)

        for p in self.vae_fix.parameters():
            p.requires_grad_(False)
        for p in self.unet_fix.parameters():
            p.requires_grad_(False)

        # FIXME: hard-coded dims
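        # Projects a 16-dim camera condition into the UNet's 1280-dim
        # timestep-embedding space.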
        self.camera_embedding = ToWeightsDType(
            TimestepEmbedding(16, 1280), self.weights_dtype
        ).to(self.device)
        # self.unet_lora.class_embedding = self.camera_embedding

        # set up LoRA layers
        # self.set_up_lora_layers(self.unet_lora)
        # self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
        #     self.device
        # )
        # self.lora_layers._load_state_dict_pre_hooks.clear()
        # self.lora_layers._state_dict_hooks.clear()

        # set up LoRA layers for pretrain
        # self.set_up_lora_layers(self.unet)
        # self.lora_layers_pretrain = AttnProcsLayers(self.unet.attn_processors).to(
        #     self.device
        # )
        # self.lora_layers_pretrain._load_state_dict_pre_hooks.clear()
        # self.lora_layers_pretrain._state_dict_hooks.clear()

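        # Trainable UNet copies, loaded fresh from the checkpoints: these are
        # the modules actually optimized during training, while the pipeline
        # UNets above stay frozen.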
        self.train_unet = UNet2DConditionModel.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            subfolder="unet",
            torch_dtype=self.weights_dtype,
        )
        # Guard the xformers call: enabling it without xformers installed
        # raises at runtime.
        if is_xformers_available():
            self.train_unet.enable_xformers_memory_efficient_attention()
        self.train_unet.enable_gradient_checkpointing()

        self.train_unet_lora = UNet2DConditionModel.from_pretrained(
            self.cfg.pretrained_model_name_or_path_lora,
            subfolder="unet",
            torch_dtype=self.weights_dtype,
        )
        if is_xformers_available():
            self.train_unet_lora.enable_xformers_memory_efficient_attention()
        self.train_unet_lora.enable_gradient_checkpointing()

        for p in self.train_unet.parameters():
            p.requires_grad_(True)
        for p in self.train_unet_lora.parameters():
            p.requires_grad_(True)
        # for p in self.lora_layers.parameters():
        #     p.requires_grad_(False)

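        # DDPM schedulers define the training noise schedule; the DPM-Solver
        # schedulers below are used only for sampling.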
        self.scheduler = DDPMScheduler.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        )

        self.scheduler_lora = DDPMScheduler.from_pretrained(
            self.cfg.pretrained_model_name_or_path_lora,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        )

        self.scheduler_sample = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.scheduler_lora_sample = DPMSolverMultistepScheduler.from_config(
            self.pipe_lora.scheduler.config
        )

        self.pipe.scheduler = self.scheduler
        self.pipe_lora.scheduler = self.scheduler_lora

        self.pipe_fix.scheduler = self.scheduler

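        # Cache the training-timestep count and initialize the default min/max
        # noise-level range used when drawing timesteps.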
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default value

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        self.grad_clip_val: Optional[float] = None

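        # Optional perceptual loss for the use_du path, kept frozen
        # (inference only).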
        if self.cfg.use_du:
            self.perceptual_loss = PerceptualLoss().eval().to(self.device)
            for p in self.perceptual_loss.parameters():
                p.requires_grad_(False)

        self.cache_frames = []

        threestudio.info(f"Loaded Stable Diffusion!")