in threestudio/systems/dreamcraft3d.py [0:0]
def training_substep(self, batch, batch_idx, guidance: str, render_type="rgb"):
    """Run one training sub-step and return the weighted total loss.

    Renders either the fixed reference view or a sampled random view,
    accumulates the supervision terms for the current guidance mode and
    stage, adds stage-specific regularizers, then weights and sums them.

    Args:
        batch: input batch; when ``guidance == "guidance"`` the random view
            is taken from ``batch["random_camera"]``.
        batch_idx: batch index (unused; kept for the Lightning-style API).
        guidance: one of "ref" (reference image supervision), "guidance"
            (diffusion guidance on a random view).
        render_type: "rgb" or "normal"; in the geometry stage "normal"
            feeds the rendered normal map to the guidance model.

    Returns:
        dict with a single key ``"loss"`` holding the total weighted loss.
    """
    gt_mask = batch["mask"]
    gt_rgb = batch["rgb"]
    gt_depth = batch["ref_depth"]
    gt_normal = batch["ref_normal"]
    mvp_mtx_ref = batch["mvp_mtx"]
    c2w_ref = batch["c2w4x4"]
    if guidance == "guidance":
        batch = batch["random_camera"]
    # Support rendering visibility mask: pass the reference camera along so
    # the renderer can determine what the reference view sees.
    batch["mvp_mtx_ref"] = mvp_mtx_ref
    batch["c2w_ref"] = c2w_ref
    out = self(batch)

    loss_prefix = f"loss_{guidance}_"

    loss_terms = {}

    def set_loss(name, value):
        # Record an unweighted loss term; weighting happens at the end.
        loss_terms[f"{loss_prefix}{name}"] = value

    guidance_eval = (
        guidance == "guidance"
        and self.cfg.freq.guidance_eval > 0
        and self.true_global_step % self.cfg.freq.guidance_eval == 0
    )

    prompt_utils = self.prompt_processor()

    # FIX: guidance_out is only assigned inside the diffusion-guidance branch
    # below; initialize it so the guidance_eval block at the end cannot raise
    # NameError when that branch is skipped (e.g. while
    # true_global_step <= cfg.freq.no_diff_steps).
    guidance_out = None

    if guidance == "ref":
        if render_type == "rgb":
            # Color loss. Use L2 loss in the coarse and geometry stages; use
            # L1 loss in the texture stage.
            if self.C(self.cfg.loss.lambda_rgb) > 0:
                # Composite the GT foreground over the rendered background so
                # pixels outside the mask compare consistently.
                gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * (
                    1 - gt_mask.float()
                )
                pred_rgb = out["comp_rgb"]
                if self.cfg.stage in ["coarse", "geometry"]:
                    set_loss("rgb", F.mse_loss(gt_rgb, pred_rgb))
                else:
                    if self.cfg.stage == "texture":
                        # Erode the mask (9x9 max-pool of the inverted mask)
                        # so boundary pixels do not dominate the L1 loss.
                        grow_mask = F.max_pool2d(
                            1 - gt_mask.float().permute(0, 3, 1, 2), (9, 9), 1, 4
                        )
                        grow_mask = (1 - grow_mask).permute(0, 2, 3, 1)
                        set_loss(
                            "rgb",
                            F.l1_loss(gt_rgb * grow_mask, pred_rgb * grow_mask),
                        )
                    else:
                        set_loss("rgb", F.l1_loss(gt_rgb, pred_rgb))

            # Mask loss
            if self.C(self.cfg.loss.lambda_mask) > 0:
                set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"]))

            # Mask binary cross-entropy loss
            if self.C(self.cfg.loss.lambda_mask_binary) > 0:
                set_loss(
                    "mask_binary",
                    F.binary_cross_entropy(
                        out["opacity"].clamp(1.0e-5, 1.0 - 1.0e-5),
                        batch["mask"].float(),
                    ),
                )

            # Depth loss: the reference depth is only relative, so align it to
            # the prediction with a scale/shift solved by least squares
            # (no_grad: the alignment itself is not optimized).
            if self.C(self.cfg.loss.lambda_depth) > 0:
                valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)].unsqueeze(1)
                valid_pred_depth = out["depth"][gt_mask].unsqueeze(1)
                with torch.no_grad():
                    A = torch.cat(
                        [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1
                    )  # [B, 2]
                    X = torch.linalg.lstsq(A, valid_pred_depth).solution  # [2, 1]
                    valid_gt_depth = A @ X  # [B, 1]
                set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth))

            # Relative depth loss: scale/shift-invariant Pearson correlation.
            if self.C(self.cfg.loss.lambda_depth_rel) > 0:
                valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)]  # [B,]
                valid_pred_depth = out["depth"][gt_mask]  # [B,]
                set_loss(
                    "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth)
                )

            # Normal loss: cosine similarity between view-space normals.
            if self.C(self.cfg.loss.lambda_normal) > 0:
                valid_gt_normal = (
                    1 - 2 * gt_normal[gt_mask.squeeze(-1)]
                )  # [B, 3]
                # FIXME: reverse x axis
                # FIX: clone before flipping so we do not mutate
                # out["comp_normal_viewspace"] in place; clone is
                # differentiable, so gradients are unchanged.
                pred_normal = out["comp_normal_viewspace"].clone()
                pred_normal[..., 0] = 1 - pred_normal[..., 0]
                valid_pred_normal = (
                    2 * pred_normal[gt_mask.squeeze(-1)] - 1
                )  # [B, 3]
                set_loss(
                    "normal",
                    1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(),
                )
    elif guidance == "guidance" and self.true_global_step > self.cfg.freq.no_diff_steps:
        # 2D diffusion guidance on the rendered random view.
        if self.cfg.stage == "geometry" and render_type == "normal":
            guidance_inp = out["comp_normal"]
        else:
            guidance_inp = out["comp_rgb"]
        guidance_out = self.guidance(
            guidance_inp,
            prompt_utils,
            **batch,
            rgb_as_latents=False,
            guidance_eval=guidance_eval,
            mask=out["mask"] if "mask" in out else None,
        )
        for name, value in guidance_out.items():
            self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                set_loss(name.split("_")[-1], value)

        # Optional 3D-aware guidance.
        if self.guidance_3d is not None:
            # FIXME: use mixed camera config
            if not self.cfg.use_mixed_camera_config or get_rank() % 2 == 0:
                guidance_3d_out = self.guidance_3d(
                    out["comp_rgb"],
                    **batch,
                    rgb_as_latents=False,
                    guidance_eval=guidance_eval,
                )
                for name, value in guidance_3d_out.items():
                    # Only scalar entries are loggable; skip tensor-valued ones.
                    if not (isinstance(value, torch.Tensor) and len(value.shape) > 0):
                        self.log(f"train/{name}_3d", value)
                    if name.startswith("loss_"):
                        set_loss("3d_" + name.split("_")[-1], value)
                # set_loss("3d_sd", guidance_out["loss_sd"])

    # Regularization
    if self.C(self.cfg.loss.lambda_normal_smooth) > 0:
        if "comp_normal" not in out:
            raise ValueError(
                "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output."
            )
        normal = out["comp_normal"]
        # Total-variation-style smoothness on adjacent pixels (both axes).
        set_loss(
            "normal_smooth",
            (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean()
            + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(),
        )

    if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0:
        if "normal" not in out:
            raise ValueError(
                "Normal is required for normal smooth loss, no normal is found in the output."
            )
        if "normal_perturb" not in out:
            raise ValueError(
                "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output."
            )
        normals = out["normal"]
        normals_perturb = out["normal_perturb"]
        set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean())

    if self.cfg.stage == "coarse":
        if self.C(self.cfg.loss.lambda_orient) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for orientation loss, no normal is found in the output."
                )
            # Penalize normals that face away from the viewing direction.
            set_loss(
                "orient",
                (
                    out["weights"].detach()
                    * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
                ).sum()
                / (out["opacity"] > 0).sum(),
            )

        # Sparsity only applies to random views (the ref view is mask-supervised).
        if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0:
            set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean())

        if self.C(self.cfg.loss.lambda_opaque) > 0:
            # Entropy term pushing opacity towards 0 or 1.
            opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
            set_loss(
                "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped)
            )

        if "lambda_eikonal" in self.cfg.loss and self.C(self.cfg.loss.lambda_eikonal) > 0:
            if "sdf_grad" not in out:
                raise ValueError(
                    "SDF grad is required for eikonal loss, no normal is found in the output."
                )
            # Eikonal constraint: SDF gradient should have unit norm.
            set_loss(
                "eikonal",
                (
                    (torch.linalg.norm(out["sdf_grad"], ord=2, dim=-1) - 1.0) ** 2
                ).mean(),
            )

        if "lambda_z_variance" in self.cfg.loss and self.C(self.cfg.loss.lambda_z_variance) > 0:
            # z variance loss proposed in HiFA: http://arxiv.org/abs/2305.18766
            # helps reduce floaters and produce solid geometry
            loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean()
            set_loss("z_variance", loss_z_variance)
    elif self.cfg.stage == "geometry":
        if self.C(self.cfg.loss.lambda_normal_consistency) > 0:
            set_loss("normal_consistency", out["mesh"].normal_consistency())
        if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0:
            set_loss("laplacian_smoothness", out["mesh"].laplacian())
    elif self.cfg.stage == "texture":
        # Periodic ControlNet-based regularization: pull the rendering towards
        # an edited version of itself via a perceptual loss.
        if self.C(self.cfg.loss.lambda_reg) > 0 and guidance == "guidance" and self.true_global_step % 5 == 0:
            rgb = out["comp_rgb"]
            rgb = F.interpolate(
                rgb.permute(0, 3, 1, 2), (512, 512), mode="bilinear"
            ).permute(0, 2, 3, 1)
            control_prompt_utils = self.control_prompt_processor()
            with torch.no_grad():
                control_dict = self.control_guidance(
                    rgb=rgb,
                    cond_rgb=rgb,
                    prompt_utils=control_prompt_utils,
                    mask=out["mask"] if "mask" in out else None,
                )
            edit_images = control_dict["edit_images"]
            temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8)
            # RGB -> BGR for OpenCV before dumping the debug image.
            cv2.imwrite(".threestudio_cache/control_debug.jpg", temp[:, :, ::-1])
            loss_reg = (
                (rgb.shape[1] // 8)
                * (rgb.shape[2] // 8)
                * self.perceptual_loss(
                    edit_images.permute(0, 3, 1, 2), rgb.permute(0, 3, 1, 2)
                ).mean()
            )
            set_loss("reg", loss_reg)
    else:
        raise ValueError(f"Unknown stage {self.cfg.stage}")

    # Weight every recorded term by its lambda_* schedule and sum.
    loss = 0.0
    for name, value in loss_terms.items():
        self.log(f"train/{name}", value)
        if name.startswith(loss_prefix):
            loss_weighted = value * self.C(
                self.cfg.loss[name.replace(loss_prefix, "lambda_")]
            )
            self.log(f"train/{name}_w", loss_weighted)
            loss += loss_weighted

    for name, value in self.cfg.loss.items():
        self.log(f"train_params/{name}", self.C(value))

    self.log(f"train/loss_{guidance}", loss)

    # FIX: also require guidance_out to be set — guidance_eval can be True on
    # steps where the diffusion-guidance branch was skipped.
    if guidance_eval and guidance_out is not None:
        self.guidance_evaluation_save(
            out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
            guidance_out["eval"],
        )
    return {"loss": loss}