in projects/deep_video_compression/_utils.py [0:0]
def compute_batch_loss(self, image1, image2):
    """Run the model on a frame pair and compute its training losses.

    ``image1`` and ``image2`` are passed to ``self.forward``. ``image2`` is
    also the target of the distortion loss, compared against the model's
    reconstruction ``image2_est``.

    Args:
        image1: First frame of the pair; must have ``ndim == 4``
            (presumably ``(batch, channels, height, width)`` -- TODO confirm
            against callers).
        image2: Second frame of the pair; must also have ``ndim == 4``.

    Returns:
        A 2-tuple ``(LossValues, OutputTensors)``:

        * ``LossValues`` carries the distortion loss plus the flow and
          residual entropy losses. Both entropy losses are computed by
          ``entropy_fn`` with the same pixel count, taken from ``image1``.
        * ``OutputTensors`` carries ``.detach()``-ed versions of the
          estimated flow, both inputs, and the reconstructed second frame.

    Raises:
        AssertionError: If ``self.loss_functions.entropy_fn`` is ``None`` or
            either input is not 4-dimensional. These are ``assert``
            statements, so they are skipped under ``python -O``.
    """
    assert self.loss_functions.entropy_fn is not None
    assert image1.ndim == image2.ndim == 4

    model_out = self.forward(image1, image2)
    loss_fns = self.loss_functions

    # Reconstruction error of the second frame against the model estimate.
    distortion = loss_fns.distortion_fn(image2, model_out.image2_est)

    # Both rate terms are normalized by the same count: the leading dim times
    # the last two dims of image1 (assumed batch * height * width).
    pixel_count = image1.shape[0] * image1.shape[-2] * image1.shape[-1]
    flow_rate = loss_fns.entropy_fn(model_out.flow_probabilities, pixel_count)
    resid_rate = loss_fns.entropy_fn(model_out.resid_probabilities, pixel_count)

    loss_values = LossValues(
        distortion_loss=distortion,
        flow_entropy_loss=flow_rate,
        resid_entropy_loss=resid_rate,
    )
    # Detached copies, presumably for logging/visualization. If these are
    # torch tensors, detach() cuts them off from the autograd graph.
    tensors = OutputTensors(
        flow=model_out.flow.detach(),
        image1=image1.detach(),
        image2=image2.detach(),
        image2_est=model_out.image2_est.detach(),
    )
    return loss_values, tensors