in projects/deep_video_compression/dvc_module.py [0:0]
def validation_step(self, batch, batch_idx):
    """Run one validation step over a group of pictures (GOP) and log metrics.

    The I-frame (frame 0) is handled by ``self.compress_iframe``, which
    returns the batch to use from here on together with the I-frame's rate.
    Each of the next ``self.num_pframes`` frames is then coded as a P-frame
    by ``self.model.compute_batch_loss``. The reference for each P-frame is
    the previous frame's reconstruction. Per-frame results are folded into a
    ``LoggingMetrics`` record and logged under
    ``f"{self.training_stage}/val_"``.

    Args:
        batch: Video batch indexed as ``batch[:, t]`` for frame ``t``. It
            must hold at least ``self.num_pframes + 1`` frames along dim 1.
            Presumably shaped ``(batch, time, channels, height, width)``;
            TODO confirm.
        batch_idx: Index of this batch. Unused here; presumably kept to match
            the Lightning hook signature.

    Raises:
        ValueError: If ``self.num_pframes < 1``. Without at least one P-frame
            no loss is ever computed, so there would be nothing to log as
            ``val_loss``.
    """
    if self.num_pframes < 1:
        raise ValueError(
            "validation_step needs at least one P-frame per GOP, got "
            f"num_pframes={self.num_pframes}"
        )

    log_key = f"{self.training_stage}/val_"

    # Compress the I-frame first; the returned rate is computed without grads.
    batch, iframe_bpp = self.compress_iframe(batch)  # bpp_total w/o grads

    # gop = "Group of Pictures".
    # The GOP rate starts from the I-frame's rate instead of 0. The reduction
    # below (num_pframes + 1) counts the I-frame as a GOP frame, so its bits
    # belong in the GOP total rather than being discarded.
    # TODO(review): confirm compute_loss_and_metrics accumulates onto
    # gop_bpp and that log_all_metrics divides the totals by `reduction`.
    logging_metrics = LoggingMetrics(
        gop_total_loss=0,
        gop_distortion_loss=0,
        gop_bpp=iframe_bpp,
        gop_flow_bpp=0,
        gop_residual_bpp=0,
    )

    image2_list = []
    image2_est_list = []
    image1 = batch[:, 0]  # reference frame for the first P-frame
    for i in range(self.num_pframes):
        image2 = batch[:, i + 1]
        loss_values, images = self.model.compute_batch_loss(image1, image2)
        # The next P-frame is predicted from this reconstruction
        # (images are detached).
        image1 = images.image2_est
        # Keep targets and reconstructions for the other distortion metrics.
        image2_list.append(images.image2)
        image2_est_list.append(images.image2_est)
        # Presumably returns this frame's loss plus the updated running
        # metrics; the loop threads logging_metrics through each call.
        loss, logging_metrics = self.compute_loss_and_metrics(
            loss_values, logging_metrics
        )

    # Stat reductions and logging: the I-frame counts as one frame of the GOP.
    reduction = self.num_pframes + 1
    # NOTE(review): `loss` is reassigned on every iteration. So "val_loss" is
    # the loss of the last P-frame only, not a GOP-level aggregate.
    self.log("val_loss", loss)
    self.log_all_metrics(
        log_key, reduction, image2_list, image2_est_list, logging_metrics
    )