in threestudio/systems/dreamcraft3d.py [0:0]
def training_step(self, batch, batch_idx):
    """Run one training iteration and return the summed sub-step loss.

    Depending on ``cfg.freq.ref_or_guidance`` this runs the ``"guidance"``
    sub-step, the ``"ref"`` sub-step, or both, via ``self.training_substep``,
    and adds up the ``"loss"`` entries they return.

    Args:
        batch: Passed through unchanged to ``self.training_substep``.
        batch_idx: Passed through unchanged to ``self.training_substep``.

    Returns:
        A dict with the single key ``"loss"`` holding the accumulated loss.

    Raises:
        ValueError: If ``cfg.freq.ref_or_guidance`` is neither
            ``"accumulate"`` nor ``"alternate"``.
    """
    freq = self.cfg.freq
    step = self.true_global_step
    mode = freq.ref_or_guidance

    if mode == "accumulate":
        # Both sub-steps run every iteration and their losses are summed.
        do_ref = True
        do_guidance = True
    elif mode == "alternate":
        # Exactly one sub-step runs: "ref" during the first ref_only_steps
        # steps and on every n_ref-th step, "guidance" otherwise.
        do_ref = step < freq.ref_only_steps or step % freq.n_ref == 0
        do_guidance = not do_ref

        # If the guidance config defines a positive only_pretrain_step, the
        # first fifth (integer division) of every only_pretrain_step-long
        # window forces a guidance-only iteration.
        # NOTE(review): this window is measured with self.global_step while
        # everything else here uses self.true_global_step -- confirm that
        # the difference is intentional.
        pretrain_period = getattr(self.guidance.cfg, "only_pretrain_step", 0)
        if (
            pretrain_period > 0
            and self.global_step % pretrain_period < pretrain_period // 5
        ):
            do_guidance = True
            do_ref = False
    else:
        # Without this branch an unrecognised mode leaves do_ref/do_guidance
        # unassigned and the method dies later with an UnboundLocalError.
        raise ValueError(
            f"Unknown cfg.freq.ref_or_guidance: {mode!r}; "
            "expected 'accumulate' or 'alternate'"
        )

    # Only the "geometry" stage alternates render types: "rgb" on every
    # n_rgb-th step and "normal" on the others. Other stages use "rgb".
    if self.cfg.stage == "geometry" and step % freq.n_rgb != 0:
        render_type = "normal"
    else:
        render_type = "rgb"

    total_loss = 0.0
    # Order matters for reproducing the original call sequence:
    # the "guidance" sub-step runs before the "ref" sub-step.
    for enabled, guidance_name in ((do_guidance, "guidance"), (do_ref, "ref")):
        if enabled:
            out = self.training_substep(
                batch, batch_idx, guidance=guidance_name, render_type=render_type
            )
            total_loss += out["loss"]

    self.log("train/loss", total_loss, prog_bar=True)
    return {"loss": total_loss}