def training_step()

in projects/deep_video_compression/dvc_module.py [0:0]


    def training_step(self, batch, batch_idx):
        """Run one manually-optimized training step over a group of pictures.

        The first frame is handled by ``self.compress_iframe`` (no grads).
        Each following P-frame is then coded from the previous reconstruction,
        and the main optimizer takes one step per P-frame. Afterwards:

        * the LR scheduler(s) are stepped;
        * GOP metrics are logged;
        * when a second optimizer is configured, the auxiliary bottleneck
          quantile loss is optimized.

        Args:
            batch: Frames of one GOP, indexed as ``batch[:, t]`` for frame
                ``t``. Presumably the layout is ``(batch, time, C, H, W)``;
                TODO confirm. After I-frame compression it must hold at least
                ``self.num_pframes + 1`` frames.
            batch_idx: Index of the batch. It is unused here and kept for the
                Lightning hook signature.
        """
        log_key = f"{self.training_stage}/train_"

        optimizers = self.optimizers()
        if isinstance(optimizers, list):
            opt1, opt2 = optimizers
        else:
            # Only the main optimizer is configured, so skip the aux update.
            opt1, opt2 = optimizers, None

        # compress the iframe and get its bpp cost (no grads)
        batch, iframe_bpp = self.compress_iframe(batch)

        # update main model params
        # gop = "Group of Pictures"; the bpp accumulator starts at the I-frame cost
        logging_metrics = LoggingMetrics(
            gop_total_loss=0,
            gop_distortion_loss=0,
            gop_bpp=iframe_bpp,
            gop_flow_bpp=0,
            gop_residual_bpp=0,
        )
        image2_list = []
        image2_est_list = []

        # Clip over every param group of the main optimizer, not just the
        # first, so that no parameter silently escapes gradient clipping.
        main_params = [
            param for group in opt1.param_groups for param in group["params"]
        ]

        image1 = batch[:, 0]
        for i in range(self.num_pframes):
            opt1.zero_grad()  # we backprop for every P-frame
            image2 = batch[:, i + 1]
            loss_values, images = self.model.compute_batch_loss(image1, image2)
            # The reconstruction is the reference for the next P-frame. The
            # images are detached, so gradients do not flow across P-frames.
            image1 = images.image2_est

            # keep track of these for other distortion metrics
            # note: these have no grads
            image2_list.append(images.image2)
            image2_est_list.append(images.image2_est)

            # loss function collection
            loss, logging_metrics = self.compute_loss_and_metrics(
                loss_values, logging_metrics
            )

            self.manual_backward(loss)
            torch.nn.utils.clip_grad_norm_(main_params, self.grad_clip_value)
            opt1.step()

        # lr step: Lightning's lr_schedulers() returns None, a single
        # scheduler, or a list of schedulers.
        schedulers = self.lr_schedulers()
        if schedulers is not None:
            if not isinstance(schedulers, list):
                schedulers = [schedulers]
            for scheduler in schedulers:
                scheduler.step()

        # stat reductions and logging (+1 accounts for the I-frame)
        reduction = self.num_pframes + 1
        self.log_all_metrics(
            log_key, reduction, image2_list, image2_est_list, logging_metrics
        )

        # auxiliary update
        # this is the loss for learning the quantiles of the bottlenecks.
        if opt2 is not None:
            opt2.zero_grad()
            aux_loss = self.model.quantile_loss()
            self.log(f"{log_key}quantile_loss", aux_loss, sync_dist=True)
            self.manual_backward(aux_loss)
            opt2.step()