threestudio/models/guidance/stable_diffusion_bsd_guidance.py [392:686]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            if class_labels is None:
                with self.disable_unet_class_embedding(pipe.unet) as unet:
                    noise_pred = unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample
            else:
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                    class_labels=class_labels,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = sample_scheduler.step(noise_pred, t, latents).prev_sample

        latents = 1 / pipe.vae.config.scaling_factor * latents
        images = pipe.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        images = images.permute(0, 2, 3, 1).float()

        return images

    def sample(
        self,
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        seed: int = 0,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        """Sample images from the base pipeline (``self.pipe``).

        Uses view-dependent text embeddings for the given cameras, 25 steps of
        ``self.scheduler_sample`` and ``self.cfg.guidance_scale``. The sampling
        noise comes from a ``torch.Generator`` seeded with ``seed``.
        """
        embeddings = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )
        if self.single_model:
            # base and LoRA share one UNet; scale 0.0 presumably switches the
            # LoRA layers off for this pass
            attn_kwargs = {"scale": 0.0}
        else:
            attn_kwargs = None
        rng = torch.Generator(device=self.device).manual_seed(seed)

        sampling_args = dict(
            pipe=self.pipe,
            sample_scheduler=self.scheduler_sample,
            text_embeddings=embeddings,
            num_inference_steps=25,
            guidance_scale=self.cfg.guidance_scale,
            cross_attention_kwargs=attn_kwargs,
            generator=rng,
        )
        return self._sample(**sampling_args)

    def sample_img2img(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        seed: int = 0,
        mask=None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build a diffusion-edited target for ``rgb`` and losses against it.

        Despite the name, this does not call ``_sample``. It encodes ``rgb`` to
        latents and delegates to ``compute_grad_du``, whose result dict is
        returned. That dict is empty on steps where no diffusion is run.

        Args:
            rgb: render in B x H x W x C layout.
            prompt_utils, elevation, azimuth, camera_distances: used to build
                the view-dependent text embeddings.
            seed: kept for interface compatibility. It is currently unused;
                the noise in ``compute_grad_du`` is drawn from the global
                torch RNG.
            mask: optional mask in B x H x W x C layout. ``None`` (the
                default) disables masking.
            **kwargs: ignored.
        """
        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        # Only reorder the mask axes when a mask was actually supplied; the
        # default of None previously crashed here.
        mask_BCHW = mask.permute(0, 3, 1, 2) if mask is not None else None
        # TODO: part of the time this input could be a DU result or the
        # reference image instead of a fresh render.
        latents = self.get_latents(rgb_BCHW, rgb_as_latents=False)

        # view-dependent text embeddings
        text_embeddings_vd = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )

        return self.compute_grad_du(
            latents, rgb_BCHW, text_embeddings_vd, mask=mask_BCHW
        )

    def sample_lora(
        self,
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        mvp_mtx: Float[Tensor, "B 4 4"],
        c2w: Float[Tensor, "B 4 4"],
        seed: int = 0,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        """Sample images from the camera-conditioned LoRA pipeline.

        The camera condition is either ``c2w`` or ``mvp_mtx``, chosen by
        ``cfg.camera_condition_type``. Each matrix is flattened per sample.
        The unconditional half of the classifier-free-guidance batch gets an
        all-zero condition.

        Raises:
            ValueError: if ``cfg.camera_condition_type`` is not
                ``"extrinsics"`` or ``"mvp"``.
        """
        # view-independent prompt embeddings
        embeddings = prompt_utils.get_text_embeddings(
            elevation, azimuth, camera_distances, view_dependent_prompting=False
        )

        condition_sources = {"extrinsics": c2w, "mvp": mvp_mtx}
        condition_type = self.cfg.camera_condition_type
        if condition_type not in condition_sources:
            raise ValueError(
                f"Unknown camera_condition_type {self.cfg.camera_condition_type}"
            )
        camera_condition = condition_sources[condition_type]

        batch = elevation.shape[0]
        flat_condition = camera_condition.view(batch, -1)
        class_labels = torch.cat(
            [flat_condition, torch.zeros_like(flat_condition)], dim=0
        )

        rng = torch.Generator(device=self.device).manual_seed(seed)
        return self._sample(
            sample_scheduler=self.scheduler_lora_sample,
            pipe=self.pipe_lora,
            text_embeddings=embeddings,
            num_inference_steps=25,
            guidance_scale=self.cfg.guidance_scale_lora,
            class_labels=class_labels,
            cross_attention_kwargs={"scale": 1.0},
            generator=rng,
        )

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        unet: UNet2DConditionModel,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        class_labels: Optional[Float[Tensor, "B 16"]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Float[Tensor, "..."]:
        """Run ``unet`` in ``self.weights_dtype`` and return its ``.sample``.

        Autocast is disabled. The latents, timestep and encoder states are cast
        explicitly to the weight dtype. The prediction is cast back to the
        dtype of ``latents``.
        """
        result_dtype = latents.dtype
        compute_dtype = self.weights_dtype
        output = unet(
            latents.to(compute_dtype),
            t.to(compute_dtype),
            encoder_hidden_states=encoder_hidden_states.to(compute_dtype),
            class_labels=class_labels,
            cross_attention_kwargs=cross_attention_kwargs,
        )
        return output.sample.to(result_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 512 512"]
    ) -> Float[Tensor, "B 4 64 64"]:
        """Encode images with the VAE and return scaled latents.

        Pixels are mapped affinely from [0, 1] to [-1, 1]. The input is cast
        to the weight dtype and encoded. A sample is drawn from the latent
        distribution and multiplied by the VAE scaling factor. The result is
        returned in the input dtype.
        """
        original_dtype = imgs.dtype
        vae_input = (imgs * 2.0 - 1.0).to(self.weights_dtype)
        distribution = self.vae.encode(vae_input).latent_dist
        scaled = distribution.sample() * self.vae.config.scaling_factor
        return scaled.to(original_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
        latent_height: int = 64,
        latent_width: int = 64,
    ) -> Float[Tensor, "B 3 512 512"]:
        """Decode latents to images in [0, 1].

        The latents are first bilinearly resized to ``(latent_height,
        latent_width)`` and divided by the VAE scaling factor. They are then
        decoded in the weight dtype. The VAE output is mapped from [-1, 1] to
        [0, 1] and clamped, and the result is returned in the input dtype.
        """
        original_dtype = latents.dtype
        resized = F.interpolate(
            latents,
            (latent_height, latent_width),
            mode="bilinear",
            align_corners=False,
        )
        unscaled = (1 / self.vae.config.scaling_factor * resized).to(
            self.weights_dtype
        )
        decoded = self.vae.decode(unscaled).sample
        return (decoded * 0.5 + 0.5).clamp(0, 1).to(original_dtype)

    @contextmanager
    def disable_unet_class_embedding(self, unet: UNet2DConditionModel):
        """Temporarily set ``unet.class_embedding`` to ``None``.

        Yields the same ``unet`` object. The previous ``class_embedding`` is
        reattached on exit, including when the ``with`` body raises.
        """
        previous = unet.class_embedding
        unet.class_embedding = None
        try:
            yield unet
        finally:
            # always put the original embedding module back
            unet.class_embedding = previous

    def compute_grad_du(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        rgb_BCHW_512: Float[Tensor, "B 3 512 512"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        mask=None,
        **kwargs,
    ):
        """Create a diffusion-edited target image and image-space losses to it.

        Diffusion runs only on steps where ``global_step`` is a multiple of
        ``cfg.per_du_step`` and greater than ``cfg.start_du_step``. On those
        steps:

        1. ``latents`` are noised to a random timestep ``t`` in
           ``[min_step, max_step)``.
        2. They are denoised with classifier-free guidance through
           ``self.unet``, with its class embedding disabled.
        3. The result is decoded as the edit target.

        Args:
            latents: encoded latents of the current render. Batch size must
                be 1.
            rgb_BCHW_512: current render in BCHW layout. It is resized to
                512x512 here.
            text_embeddings: CFG embeddings. The first half of the UNet output
                is treated as text-conditioned, the second half as
                unconditional.
            mask: optional BCHW mask. Where it is 1 the guided noise prediction
                is used; where it is 0 the originally sampled noise is used
                instead. This presumably keeps those regions close to the
                input. NOTE(review): assumes the channel count broadcasts
                against the latents (e.g. 1).
            **kwargs: ignored.

        Returns:
            ``{}`` on steps without diffusion. Otherwise a dict with:

            - ``loss_l1``: summed L1 loss.
            - ``loss_p``: summed perceptual loss.
            - ``edit_image``: the detached target in B x 512 x 512 x C layout.
        """
        batch_size, _, _, _ = latents.shape
        rgb_BCHW_512 = F.interpolate(rgb_BCHW_512, (512, 512), mode="bilinear")
        assert batch_size == 1
        need_diffusion = (
            self.global_step % self.cfg.per_du_step == 0
            and self.global_step > self.cfg.start_du_step
        )
        guidance_out = {}
        if not need_diffusion:
            return guidance_out

        t = torch.randint(
            self.min_step,
            self.max_step,
            [1],
            dtype=torch.long,
            device=self.device,
        )
        t_int = int(t.item())

        if mask is not None:
            # Resize to the latent grid instead of a fixed 64x64 so the mask
            # broadcasts against the noise prediction at any latent resolution.
            mask = F.interpolate(
                mask, tuple(latents.shape[-2:]), mode="bilinear", antialias=True
            )

        # Loop-invariant; scale 0.0 presumably disables the LoRA layers when
        # base and LoRA share one UNet.
        cross_attention_kwargs = {"scale": 0.0} if self.single_model else None

        # HACK: num_train_timesteps is temporarily set to t, so set_timesteps()
        # spaces the inference timesteps over [0, t) and denoising starts near
        # the noise level added below. The scheduler's step() presumably reads
        # the same config value; confirm for the configured scheduler class.
        # The shared scheduler is restored afterwards so other users of
        # self.scheduler do not see the per-call value.
        original_num_train_timesteps = self.scheduler.config.num_train_timesteps
        try:
            self.scheduler.config.num_train_timesteps = t_int
            # diffusers schedulers reject more inference steps than training
            # timesteps, so clamp instead of failing when t is small.
            self.scheduler.set_timesteps(min(self.cfg.du_diffusion_steps, t_int))

            with torch.no_grad():
                noise = torch.randn_like(latents)
                latents = self.scheduler.add_noise(latents, noise, t)  # type: ignore
                for timestep in self.scheduler.timesteps:
                    latent_model_input = torch.cat([latents] * 2)
                    with self.disable_unet_class_embedding(self.unet) as unet:
                        noise_pred = self.forward_unet(
                            unet,
                            latent_model_input,
                            timestep,
                            encoder_hidden_states=text_embeddings,
                            cross_attention_kwargs=cross_attention_kwargs,
                        )
                    # classifier-free guidance
                    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                    if mask is not None:
                        noise_pred = mask * noise_pred + (1 - mask) * noise
                    latents = self.scheduler.step(
                        noise_pred, timestep, latents
                    ).prev_sample
        finally:
            self.scheduler.config.num_train_timesteps = original_num_train_timesteps

        # The edit is a fixed target, so build it without tracking gradients.
        with torch.no_grad():
            target_BCHW = F.interpolate(
                self.decode_latents(latents), (512, 512), mode="bilinear"
            )

        guidance_out.update(
            {
                "loss_l1": F.l1_loss(rgb_BCHW_512, target_BCHW, reduction="sum"),
                "loss_p": self.perceptual_loss(
                    rgb_BCHW_512.contiguous(),
                    target_BCHW.contiguous(),
                ).sum(),
                "edit_image": target_BCHW.permute(0, 2, 3, 1).detach(),
            }
        )
        return guidance_out

    def compute_grad_vsd(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        text_embeddings_vd: Float[Tensor, "BB 77 768"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        camera_condition: Float[Tensor, "B 4 4"],
    ):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



threestudio/models/guidance/stable_diffusion_vsd_guidance.py [335:628]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            if class_labels is None:
                with self.disable_unet_class_embedding(pipe.unet) as unet:
                    noise_pred = unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample
            else:
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                    class_labels=class_labels,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = sample_scheduler.step(noise_pred, t, latents).prev_sample

        latents = 1 / pipe.vae.config.scaling_factor * latents
        images = pipe.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        images = images.permute(0, 2, 3, 1).float()
        return images

    def sample(
        self,
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        seed: int = 0,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        """Sample images from the base pipeline (``self.pipe``).

        Uses view-dependent text embeddings for the given cameras, 25 steps of
        ``self.scheduler_sample`` and ``self.cfg.guidance_scale``. The sampling
        noise comes from a ``torch.Generator`` seeded with ``seed``.
        """
        embeddings = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )
        if self.single_model:
            # base and LoRA share one UNet; scale 0.0 presumably switches the
            # LoRA layers off for this pass
            attn_kwargs = {"scale": 0.0}
        else:
            attn_kwargs = None
        rng = torch.Generator(device=self.device).manual_seed(seed)

        sampling_args = dict(
            pipe=self.pipe,
            sample_scheduler=self.scheduler_sample,
            text_embeddings=embeddings,
            num_inference_steps=25,
            guidance_scale=self.cfg.guidance_scale,
            cross_attention_kwargs=attn_kwargs,
            generator=rng,
        )
        return self._sample(**sampling_args)

    def sample_img2img(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        seed: int = 0,
        mask=None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build a diffusion-edited target for ``rgb`` and losses against it.

        Despite the name, this does not call ``_sample``. It encodes ``rgb`` to
        latents and delegates to ``compute_grad_du``, whose result dict is
        returned. That dict is empty on steps where no diffusion is run.

        Args:
            rgb: render in B x H x W x C layout.
            prompt_utils, elevation, azimuth, camera_distances: used to build
                the view-dependent text embeddings.
            seed: kept for interface compatibility. It is currently unused;
                the noise in ``compute_grad_du`` is drawn from the global
                torch RNG.
            mask: optional mask in B x H x W x C layout. ``None`` (the
                default) disables masking.
            **kwargs: ignored.
        """
        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        # Only reorder the mask axes when a mask was actually supplied; the
        # default of None previously crashed here.
        mask_BCHW = mask.permute(0, 3, 1, 2) if mask is not None else None
        # TODO: part of the time this input could be a DU result or the
        # reference image instead of a fresh render.
        latents = self.get_latents(rgb_BCHW, rgb_as_latents=False)

        # view-dependent text embeddings
        text_embeddings_vd = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )

        return self.compute_grad_du(
            latents, rgb_BCHW, text_embeddings_vd, mask=mask_BCHW
        )

    def sample_lora(
        self,
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        mvp_mtx: Float[Tensor, "B 4 4"],
        c2w: Float[Tensor, "B 4 4"],
        seed: int = 0,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        """Sample images from the camera-conditioned LoRA pipeline.

        The camera condition is either ``c2w`` or ``mvp_mtx``, chosen by
        ``cfg.camera_condition_type``. Each matrix is flattened per sample.
        The unconditional half of the classifier-free-guidance batch gets an
        all-zero condition.

        Raises:
            ValueError: if ``cfg.camera_condition_type`` is not
                ``"extrinsics"`` or ``"mvp"``.
        """
        # view-independent prompt embeddings
        embeddings = prompt_utils.get_text_embeddings(
            elevation, azimuth, camera_distances, view_dependent_prompting=False
        )

        condition_sources = {"extrinsics": c2w, "mvp": mvp_mtx}
        condition_type = self.cfg.camera_condition_type
        if condition_type not in condition_sources:
            raise ValueError(
                f"Unknown camera_condition_type {self.cfg.camera_condition_type}"
            )
        camera_condition = condition_sources[condition_type]

        batch = elevation.shape[0]
        flat_condition = camera_condition.view(batch, -1)
        class_labels = torch.cat(
            [flat_condition, torch.zeros_like(flat_condition)], dim=0
        )

        rng = torch.Generator(device=self.device).manual_seed(seed)
        return self._sample(
            sample_scheduler=self.scheduler_lora_sample,
            pipe=self.pipe_lora,
            text_embeddings=embeddings,
            num_inference_steps=25,
            guidance_scale=self.cfg.guidance_scale_lora,
            class_labels=class_labels,
            cross_attention_kwargs={"scale": 1.0},
            generator=rng,
        )

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        unet: UNet2DConditionModel,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        class_labels: Optional[Float[Tensor, "B 16"]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Float[Tensor, "..."]:
        """Run ``unet`` in ``self.weights_dtype`` and return its ``.sample``.

        Autocast is disabled. The latents, timestep and encoder states are cast
        explicitly to the weight dtype. The prediction is cast back to the
        dtype of ``latents``.
        """
        result_dtype = latents.dtype
        compute_dtype = self.weights_dtype
        output = unet(
            latents.to(compute_dtype),
            t.to(compute_dtype),
            encoder_hidden_states=encoder_hidden_states.to(compute_dtype),
            class_labels=class_labels,
            cross_attention_kwargs=cross_attention_kwargs,
        )
        return output.sample.to(result_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 512 512"]
    ) -> Float[Tensor, "B 4 64 64"]:
        """Encode images with the VAE and return scaled latents.

        Pixels are mapped affinely from [0, 1] to [-1, 1]. The input is cast
        to the weight dtype and encoded. A sample is drawn from the latent
        distribution and multiplied by the VAE scaling factor. The result is
        returned in the input dtype.
        """
        original_dtype = imgs.dtype
        vae_input = (imgs * 2.0 - 1.0).to(self.weights_dtype)
        distribution = self.vae.encode(vae_input).latent_dist
        scaled = distribution.sample() * self.vae.config.scaling_factor
        return scaled.to(original_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
        latent_height: int = 64,
        latent_width: int = 64,
    ) -> Float[Tensor, "B 3 512 512"]:
        """Decode latents to images in [0, 1].

        The latents are first bilinearly resized to ``(latent_height,
        latent_width)`` and divided by the VAE scaling factor. They are then
        decoded in the weight dtype. The VAE output is mapped from [-1, 1] to
        [0, 1] and clamped, and the result is returned in the input dtype.
        """
        original_dtype = latents.dtype
        resized = F.interpolate(
            latents,
            (latent_height, latent_width),
            mode="bilinear",
            align_corners=False,
        )
        unscaled = (1 / self.vae.config.scaling_factor * resized).to(
            self.weights_dtype
        )
        decoded = self.vae.decode(unscaled).sample
        return (decoded * 0.5 + 0.5).clamp(0, 1).to(original_dtype)

    @contextmanager
    def disable_unet_class_embedding(self, unet: UNet2DConditionModel):
        """Temporarily set ``unet.class_embedding`` to ``None``.

        Yields the same ``unet`` object. The previous ``class_embedding`` is
        reattached on exit, including when the ``with`` body raises.
        """
        previous = unet.class_embedding
        unet.class_embedding = None
        try:
            yield unet
        finally:
            # always put the original embedding module back
            unet.class_embedding = previous

    def compute_grad_du(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        rgb_BCHW_512: Float[Tensor, "B 3 512 512"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        mask=None,
        **kwargs,
    ):
        """Create a diffusion-edited target image and image-space losses to it.

        Diffusion runs only on steps where ``global_step`` is a multiple of
        ``cfg.per_du_step`` and greater than ``cfg.start_du_step``. On those
        steps:

        1. ``latents`` are noised to a random timestep ``t`` in
           ``[min_step, max_step)``.
        2. They are denoised with classifier-free guidance through
           ``self.unet``, with its class embedding disabled.
        3. The result is decoded as the edit target.

        Args:
            latents: encoded latents of the current render. Batch size must
                be 1.
            rgb_BCHW_512: current render in BCHW layout. It is resized to
                512x512 here.
            text_embeddings: CFG embeddings. The first half of the UNet output
                is treated as text-conditioned, the second half as
                unconditional.
            mask: optional BCHW mask. Where it is 1 the guided noise prediction
                is used; where it is 0 the originally sampled noise is used
                instead. This presumably keeps those regions close to the
                input. NOTE(review): assumes the channel count broadcasts
                against the latents (e.g. 1).
            **kwargs: ignored.

        Returns:
            ``{}`` on steps without diffusion. Otherwise a dict with:

            - ``loss_l1``: summed L1 loss.
            - ``loss_p``: summed perceptual loss.
            - ``edit_image``: the detached target in B x 512 x 512 x C layout.
        """
        batch_size, _, _, _ = latents.shape
        rgb_BCHW_512 = F.interpolate(rgb_BCHW_512, (512, 512), mode="bilinear")
        assert batch_size == 1
        need_diffusion = (
            self.global_step % self.cfg.per_du_step == 0
            and self.global_step > self.cfg.start_du_step
        )
        guidance_out = {}
        if not need_diffusion:
            return guidance_out

        t = torch.randint(
            self.min_step,
            self.max_step,
            [1],
            dtype=torch.long,
            device=self.device,
        )
        t_int = int(t.item())

        if mask is not None:
            # Resize to the latent grid instead of a fixed 64x64 so the mask
            # broadcasts against the noise prediction at any latent resolution.
            mask = F.interpolate(
                mask, tuple(latents.shape[-2:]), mode="bilinear", antialias=True
            )

        # Loop-invariant; scale 0.0 presumably disables the LoRA layers when
        # base and LoRA share one UNet.
        cross_attention_kwargs = {"scale": 0.0} if self.single_model else None

        # HACK: num_train_timesteps is temporarily set to t, so set_timesteps()
        # spaces the inference timesteps over [0, t) and denoising starts near
        # the noise level added below. The scheduler's step() presumably reads
        # the same config value; confirm for the configured scheduler class.
        # The shared scheduler is restored afterwards so other users of
        # self.scheduler do not see the per-call value.
        original_num_train_timesteps = self.scheduler.config.num_train_timesteps
        try:
            self.scheduler.config.num_train_timesteps = t_int
            # diffusers schedulers reject more inference steps than training
            # timesteps, so clamp instead of failing when t is small.
            self.scheduler.set_timesteps(min(self.cfg.du_diffusion_steps, t_int))

            with torch.no_grad():
                noise = torch.randn_like(latents)
                latents = self.scheduler.add_noise(latents, noise, t)  # type: ignore
                for timestep in self.scheduler.timesteps:
                    latent_model_input = torch.cat([latents] * 2)
                    with self.disable_unet_class_embedding(self.unet) as unet:
                        noise_pred = self.forward_unet(
                            unet,
                            latent_model_input,
                            timestep,
                            encoder_hidden_states=text_embeddings,
                            cross_attention_kwargs=cross_attention_kwargs,
                        )
                    # classifier-free guidance
                    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                    if mask is not None:
                        noise_pred = mask * noise_pred + (1 - mask) * noise
                    latents = self.scheduler.step(
                        noise_pred, timestep, latents
                    ).prev_sample
        finally:
            self.scheduler.config.num_train_timesteps = original_num_train_timesteps

        # The edit is a fixed target, so build it without tracking gradients.
        with torch.no_grad():
            target_BCHW = F.interpolate(
                self.decode_latents(latents), (512, 512), mode="bilinear"
            )

        guidance_out.update(
            {
                "loss_l1": F.l1_loss(rgb_BCHW_512, target_BCHW, reduction="sum"),
                "loss_p": self.perceptual_loss(
                    rgb_BCHW_512.contiguous(),
                    target_BCHW.contiguous(),
                ).sum(),
                "edit_image": target_BCHW.permute(0, 2, 3, 1).detach(),
            }
        )
        return guidance_out

    def compute_grad_vsd(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        text_embeddings_vd: Float[Tensor, "BB 77 768"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        camera_condition: Float[Tensor, "B 4 4"],
    ):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



