# projects/deep_video_compression/dvc_module.py
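# Context for this excerpt (assumptions, not shown in the snippet): the method below
# belongs to the DVC LightningModule, `torch` is imported at module level, and
# `self.automatic_optimization = False` is set in __init__, which Lightning requires
# before calling self.manual_backward() and stepping optimizers by hand.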
def training_step(self, batch, batch_idx):
    log_key = f"{self.training_stage}/train_"
    if isinstance(self.optimizers(), list):
        opt1, opt2 = self.optimizers()
    else:
        opt1 = self.optimizers()
        opt2 = None
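    # opt2, when present, is the auxiliary optimizer for the bottleneck quantile
    # parameters (see the quantile_loss update at the end of this method)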
    # compress the I-frame and get its bpp (bits per pixel) cost (no grads)
    batch, iframe_bpp = self.compress_iframe(batch)
    # update main model params
    # gop = "Group of Pictures"
    logging_metrics = LoggingMetrics(
        gop_total_loss=0,
        gop_distortion_loss=0,
        gop_bpp=iframe_bpp,
        gop_flow_bpp=0,
        gop_residual_bpp=0,
    )
    image2_list = []
    image2_est_list = []
    image1 = batch[:, 0]
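    # image1 is the reference frame: the decoded I-frame to start, then the previous
    # P-frame reconstruction on each subsequent loop iteration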
    for i in range(self.num_pframes):
        opt1.zero_grad()  # we backprop for every P-frame
        image2 = batch[:, i + 1]
        loss_values, images = self.model.compute_batch_loss(image1, image2)
        image1 = images.image2_est  # images are detached
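        # because the reconstructions are detached, gradients do not flow across
        # P-frames; each frame gets its own backward pass and optimizer step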
        # keep track of these for other distortion metrics
        # note: these have no grads
        image2_list.append(images.image2)
        image2_est_list.append(images.image2_est)
        # loss function collection
        loss, logging_metrics = self.compute_loss_and_metrics(
            loss_values, logging_metrics
        )
        self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(
            opt1.param_groups[0]["params"], self.grad_clip_value
        )
        opt1.step()
    # lr step
    if self.lr_schedulers() is not None:
        self.lr_schedulers().step()
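    # with manual optimization, Lightning does not step schedulers automatically,
    # so the scheduler is advanced here once per training_step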
    # stat reductions and logging
    reduction = self.num_pframes + 1
    self.log_all_metrics(
        log_key, reduction, image2_list, image2_est_list, logging_metrics
    )
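    # reduction = I-frame + num_pframes P-frames; presumably used by log_all_metrics
    # to average the accumulated GOP totals over the whole group of pictures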
    # auxiliary update
    # this is the loss for learning the quantiles of the bottlenecks.
    if opt2 is not None:
        opt2.zero_grad()
        aux_loss = self.model.quantile_loss()
        self.log(f"{log_key}quantile_loss", aux_loss, sync_dist=True)
        self.manual_backward(aux_loss)
        opt2.step()
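
# Illustrative sketch only, not code from this repo: a configure_optimizers() that is
# consistent with the list/single-optimizer handling in training_step above. The way
# parameters are split (bottleneck quantile tensors vs. everything else), the naming
# convention, and the learning rates are assumptions for illustration.
def configure_optimizers(self):
    main_params, quantile_params = [], []
    for name, param in self.model.named_parameters():
        # hypothetical naming convention: quantile tensors end with "quantiles"
        if name.endswith("quantiles"):
            quantile_params.append(param)
        else:
            main_params.append(param)
    opt1 = torch.optim.Adam(main_params, lr=1e-4)
    if not quantile_params:
        return opt1  # training_step then sees a single optimizer and sets opt2 = None
    opt2 = torch.optim.Adam(quantile_params, lr=1e-3)
    return [opt1, opt2]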