def training_substep()

in threestudio/systems/dreamcraft3d.py [0:0]


    def training_substep(self, batch, batch_idx, guidance: str, render_type="rgb"):
        """Render one view, assemble the loss terms for it and return the weighted sum.

        Args:
            batch: training batch. The reference-view entries ``mask``, ``rgb``,
                ``ref_depth``, ``ref_normal``, ``mvp_mtx`` and ``c2w4x4`` are always
                read from it. When ``guidance == "guidance"`` the nested
                ``batch["random_camera"]`` dict is rendered instead. The rendered
                dict is updated in place with ``mvp_mtx_ref`` and ``c2w_ref``.
            batch_idx: not used by this method.
            guidance: one of "ref" (reference image supervision) or "guidance"
                (diffusion guidance on the random camera). It also selects the
                ``loss_{guidance}_`` prefix used for loss names.
            render_type: "rgb" enables the rgb/mask/depth reference losses; in the
                geometry stage "normal" feeds ``comp_normal`` to the guidance model
                instead of ``comp_rgb``.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the sum of every collected term
            multiplied by its ``cfg.loss.lambda_*`` weight (``0.0`` if no term was
            collected).

        Raises:
            ValueError: if a regularizer is enabled but the render output lacks the
                entry it needs, or if ``cfg.stage`` is not "coarse", "geometry" or
                "texture".
        """

        gt_mask = batch["mask"]
        gt_rgb = batch["rgb"]
        gt_depth = batch["ref_depth"]
        gt_normal = batch["ref_normal"]
        mvp_mtx_ref = batch["mvp_mtx"]
        c2w_ref = batch["c2w4x4"]

        if guidance == "guidance":
            batch = batch["random_camera"]

        # Support rendering visibility mask
        batch["mvp_mtx_ref"] = mvp_mtx_ref
        batch["c2w_ref"] = c2w_ref

        out = self(batch)
        loss_prefix = f"loss_{guidance}_"

        loss_terms = {}

        def set_loss(name, value):
            loss_terms[f"{loss_prefix}{name}"] = value

        # Diffusion guidance only runs after the first `no_diff_steps` steps.
        use_diffusion = (
            guidance == "guidance"
            and self.true_global_step > self.cfg.freq.no_diff_steps
        )
        # Guidance evaluation consumes `guidance_out`, so it may only be scheduled
        # on steps where the diffusion guidance branch actually runs.
        guidance_eval = (
            use_diffusion
            and self.cfg.freq.guidance_eval > 0
            and self.true_global_step % self.cfg.freq.guidance_eval == 0
        )

        prompt_utils = self.prompt_processor()

        if guidance == "ref":
            if render_type == "rgb":
                # color loss. Use l2 loss in coarse and geometry stage; use l1 loss in texture stage.
                if self.C(self.cfg.loss.lambda_rgb) > 0:
                    # Outside the mask the target is the rendered background, so
                    # only the masked region contributes a colour error.
                    gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * (
                        1 - gt_mask.float()
                    )
                    pred_rgb = out["comp_rgb"]
                    if self.cfg.stage in ["coarse", "geometry"]:
                        set_loss("rgb", F.mse_loss(gt_rgb, pred_rgb))
                    elif self.cfg.stage == "texture":
                        # Max-pool the inverted mask with a 9x9 window, then invert
                        # back, and weight both images with the result.
                        grow_mask = F.max_pool2d(
                            1 - gt_mask.float().permute(0, 3, 1, 2), (9, 9), 1, 4
                        )
                        grow_mask = (1 - grow_mask).permute(0, 2, 3, 1)
                        set_loss(
                            "rgb",
                            F.l1_loss(gt_rgb * grow_mask, pred_rgb * grow_mask),
                        )
                    else:
                        set_loss("rgb", F.l1_loss(gt_rgb, pred_rgb))

                # mask loss
                if self.C(self.cfg.loss.lambda_mask) > 0:
                    set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"]))

                # mask binary cross loss
                if self.C(self.cfg.loss.lambda_mask_binary) > 0:
                    set_loss(
                        "mask_binary",
                        F.binary_cross_entropy(
                            out["opacity"].clamp(1.0e-5, 1.0 - 1.0e-5),
                            gt_mask.float(),
                        ),
                    )

                # depth loss
                if self.C(self.cfg.loss.lambda_depth) > 0:
                    valid_gt_depth = gt_depth[gt_mask.squeeze(-1)].unsqueeze(1)
                    valid_pred_depth = out["depth"][gt_mask].unsqueeze(1)
                    # Least-squares fit of a scale and offset mapping the reference
                    # depth onto the predicted depth; no gradient flows through it.
                    with torch.no_grad():
                        A = torch.cat(
                            [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1
                        )  # [B, 2]
                        X = torch.linalg.lstsq(A, valid_pred_depth).solution  # [2, 1]
                        valid_gt_depth = A @ X  # [B, 1]
                    set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth))

                # relative depth loss
                if self.C(self.cfg.loss.lambda_depth_rel) > 0:
                    valid_gt_depth = gt_depth[gt_mask.squeeze(-1)]  # [B,]
                    valid_pred_depth = out["depth"][gt_mask]  # [B,]
                    set_loss(
                        "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth)
                    )

            # normal loss
            if self.C(self.cfg.loss.lambda_normal) > 0:
                valid_gt_normal = (
                    1 - 2 * gt_normal[gt_mask.squeeze(-1)]
                )  # [B, 3]
                # FIXME: reverse x axis
                # Built out of place so the rendered tensor in `out` is left
                # untouched (no in-place write on a tensor in the autograd graph).
                rendered_normal = out["comp_normal_viewspace"]
                pred_normal = torch.cat(
                    [1 - rendered_normal[..., :1], rendered_normal[..., 1:]], dim=-1
                )
                valid_pred_normal = (
                    2 * pred_normal[gt_mask.squeeze(-1)] - 1
                )  # [B, 3]
                set_loss(
                    "normal",
                    1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(),
                )

        elif use_diffusion:
            if self.cfg.stage == "geometry" and render_type == "normal":
                guidance_inp = out["comp_normal"]
            else:
                guidance_inp = out["comp_rgb"]
            guidance_out = self.guidance(
                guidance_inp,
                prompt_utils,
                **batch,
                rgb_as_latents=False,
                guidance_eval=guidance_eval,
                mask=out["mask"] if "mask" in out else None,
            )
            for name, value in guidance_out.items():
                self.log(f"train/{name}", value)
                if name.startswith("loss_"):
                    # Keep only the last "_"-separated token, e.g. "loss_sd" -> "sd".
                    set_loss(name.split("_")[-1], value)

            if self.guidance_3d is not None:

                # FIXME: use mixed camera config
                if not self.cfg.use_mixed_camera_config or get_rank() % 2 == 0:
                    guidance_3d_out = self.guidance_3d(
                        out["comp_rgb"],
                        **batch,
                        rgb_as_latents=False,
                        guidance_eval=guidance_eval,
                    )
                    for name, value in guidance_3d_out.items():
                        # Tensors with at least one dimension are not logged.
                        if not (isinstance(value, torch.Tensor) and len(value.shape) > 0):
                            self.log(f"train/{name}_3d", value)
                        if name.startswith("loss_"):
                            set_loss("3d_" + name.split("_")[-1], value)
                    # set_loss("3d_sd", guidance_out["loss_sd"])

        # Regularization
        if self.C(self.cfg.loss.lambda_normal_smooth) > 0:
            if "comp_normal" not in out:
                raise ValueError(
                    "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output."
                )
            normal = out["comp_normal"]
            # Mean squared difference between neighbours along axes 1 and 2.
            set_loss(
                "normal_smooth",
                (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean()
                + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(),
            )

        if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for normal smooth loss, no normal is found in the output."
                )
            if "normal_perturb" not in out:
                raise ValueError(
                    "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output."
                )
            normals = out["normal"]
            normals_perturb = out["normal_perturb"]
            set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean())

        if self.cfg.stage == "coarse":
            if self.C(self.cfg.loss.lambda_orient) > 0:
                if "normal" not in out:
                    raise ValueError(
                        "Normal is required for orientation loss, no normal is found in the output."
                    )
                set_loss(
                    "orient",
                    (
                        out["weights"].detach()
                        * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
                    ).sum()
                    / (out["opacity"] > 0).sum(),
                )

            if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0:
                set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean())

            if self.C(self.cfg.loss.lambda_opaque) > 0:
                opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
                # The clamped opacity is passed as both arguments.
                set_loss(
                    "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped)
                )

            if "lambda_eikonal" in self.cfg.loss and self.C(self.cfg.loss.lambda_eikonal) > 0:
                if "sdf_grad" not in out:
                    raise ValueError(
                        "SDF grad is required for eikonal loss, no normal is found in the output."
                    )
                set_loss(
                    "eikonal",
                    (
                        (torch.linalg.norm(out["sdf_grad"], ord=2, dim=-1) - 1.0) ** 2
                    ).mean(),
                )

            if "lambda_z_variance" in self.cfg.loss and self.C(self.cfg.loss.lambda_z_variance) > 0:
                # z variance loss proposed in HiFA: http://arxiv.org/abs/2305.18766
                # helps reduce floaters and produce solid geometry
                loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean()
                set_loss("z_variance", loss_z_variance)

        elif self.cfg.stage == "geometry":
            if self.C(self.cfg.loss.lambda_normal_consistency) > 0:
                set_loss("normal_consistency", out["mesh"].normal_consistency())
            if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0:
                set_loss("laplacian_smoothness", out["mesh"].laplacian())
        elif self.cfg.stage == "texture":
            # The control-guidance regularizer runs only on every 5th global step
            # of the "guidance" sub-step.
            if (
                self.C(self.cfg.loss.lambda_reg) > 0
                and guidance == "guidance"
                and self.true_global_step % 5 == 0
            ):
                rgb = out["comp_rgb"]
                rgb = F.interpolate(
                    rgb.permute(0, 3, 1, 2), (512, 512), mode='bilinear'
                ).permute(0, 2, 3, 1)
                control_prompt_utils = self.control_prompt_processor()
                # The edited image is produced without gradients; gradients reach
                # the render only through `rgb` in the perceptual loss below.
                with torch.no_grad():
                    control_dict = self.control_guidance(
                        rgb=rgb,
                        cond_rgb=rgb,
                        prompt_utils=control_prompt_utils,
                        mask=out["mask"] if "mask" in out else None,
                    )

                    edit_images = control_dict["edit_images"]
                    # Debug dump of the first edited image; channel order is
                    # reversed before writing with OpenCV.
                    temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8)
                    cv2.imwrite(".threestudio_cache/control_debug.jpg", temp[:, :, ::-1])

                loss_reg = (
                    (rgb.shape[1] // 8)
                    * (rgb.shape[2] // 8)
                    * self.perceptual_loss(
                        edit_images.permute(0, 3, 1, 2), rgb.permute(0, 3, 1, 2)
                    ).mean()
                )
                set_loss("reg", loss_reg)
        else:
            raise ValueError(f"Unknown stage {self.cfg.stage}")

        # Weight every collected term by the matching `cfg.loss.lambda_*` entry.
        loss = 0.0
        for name, value in loss_terms.items():
            self.log(f"train/{name}", value)
            if name.startswith(loss_prefix):
                loss_weighted = value * self.C(
                    self.cfg.loss[name.replace(loss_prefix, "lambda_")]
                )
                self.log(f"train/{name}_w", loss_weighted)
                loss += loss_weighted

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        self.log(f"train/loss_{guidance}", loss)

        if guidance_eval:
            self.guidance_evaluation_save(
                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                guidance_out["eval"],
            )

        return {"loss": loss}