in depth_fine_tuning.py [0:0]
def eval_and_save(self, criterion, data_loader, suf) -> Dict[str, torch.Tensor]:
    """Evaluate the model over data_loader, save depth maps, and dump losses.

    Note this function assumes the structure of the data produced by
    data_loader: each batch is a (stacked_img, metadata) pair, where
    metadata["geometry_consistency"]["indices"] holds the frame-index pairs
    for the batch.

    Args:
        criterion: loss callable; returns (total_loss, {loss_name: per-pair
            loss values}).
        data_loader: yields (stacked_img, metadata) batches.
        suf: filename suffix appended to the saved depth maps and loss json.

    Side effects:
        Writes per-frame inverse-depth maps (.raw and .png) and a
        loss{suf}.json file under <out_dir>/eval, sets self.vis_depth_scale
        on first use, and prints a per-pair / mean loss summary to stdout.

    Returns:
        Dict mapping loss name to a 1-D tensor of per-pair loss values.
    """
    N = len(data_loader.dataset)
    loss_dict = {}
    saved_frames = set()
    total_index = 0
    max_frame_index = 0
    all_pairs = []

    # zip with range(N) caps the iteration at the dataset size.
    for _, data in zip(range(N), data_loader):
        data = to_device(data)
        stacked_img, metadata = data
        with torch.no_grad():
            depth = self.model(stacked_img, metadata)

        batch_indices = (
            metadata["geometry_consistency"]["indices"].cpu().numpy().tolist()
        )
        # Update the maximum frame index and pairs list.
        max_frame_index = max(max_frame_index, max(itertools.chain(*batch_indices)))
        all_pairs += batch_indices

        # Compute and store losses, keyed by the stringified index pair.
        _, loss_meta = criterion(
            depth, metadata, parameters=self.model.parameters(),
        )
        for loss_name, losses in loss_meta.items():
            if loss_name not in loss_dict:
                loss_dict[loss_name] = {}
            for indices, loss in zip(batch_indices, losses):
                loss_dict[loss_name][str(indices)] = loss.item()

        # Save depth maps (stored as inverse depth).
        inv_depths_batch = 1.0 / depth.cpu().detach().numpy()
        if self.vis_depth_scale is None:
            # Single visualization scale for the whole dataset.
            self.vis_depth_scale = inv_depths_batch.max()

        for inv_depths, indices in zip(inv_depths_batch, batch_indices):
            for inv_depth, index in zip(inv_depths, indices):
                # Only save frames not saved before.
                if index in saved_frames:
                    continue
                saved_frames.add(index)
                fn_pre = pjoin(
                    self.out_dir, "eval", "depth_{:06d}{}".format(index, suf)
                )
                image_io.save_raw_float32_image(fn_pre + ".raw", inv_depth)
                inv_depth_vis = visualization.visualize_depth(
                    inv_depth, depth_min=0, depth_max=self.vis_depth_scale
                )
                cv2.imwrite(fn_pre + ".png", inv_depth_vis)
                total_index += 1

    # Aggregate per-pair losses into tensors and record the mean of each.
    loss_meta = {
        loss_name: torch.tensor(tuple(loss.values()))
        for loss_name, loss in loss_dict.items()
    }
    loss_dict["mean"] = {k: v.mean().item() for k, v in loss_meta.items()}

    with open(pjoin(self.out_dir, "eval", "loss{}.json".format(suf)), "w") as f:
        json.dump(loss_dict, f)

    # Print verbose summary to stdout.
    # Digit-count via len(str(...)) instead of ceil(log10(...)): the latter
    # raises ValueError when max_frame_index == 0 (empty loader) and
    # under-counts digits at exact powers of 10.
    index_width = len(str(max_frame_index))
    loss_names = list(loss_dict.keys())
    loss_names.remove("mean")
    loss_format = {}
    for name in loss_names:
        max_value = max(loss_dict[name].values())
        # Field width: integer digits (at least 1, robust to losses < 1 or
        # == 0 where log10 would fail) + "." + 6 decimal places.
        int_width = len(str(int(max_value))) if max_value >= 1 else 1
        loss_format[name] = f"{int_width + 7}.6f"

    for pair in sorted(all_pairs):
        line = f"({pair[0]:{index_width}d}, {pair[1]:{index_width}d}): "
        line += ", ".join(
            f"{name}: {loss_dict[name][str(pair)]:{loss_format[name]}}"
            for name in loss_names
        )
        print(line)

    # BUG FIX: this previously formatted loss_dict[name][str(pair)], i.e. the
    # losses of the *last* pair from the loop above, mislabeled as the mean.
    # Use the recorded per-loss means instead.
    print("Mean: " + " " * (2 * index_width) + ", ".join(
        f"{name}: {loss_dict['mean'][name]:{loss_format[name]}}"
        for name in loss_names
    ))
    return loss_meta