def render_img()

in src/plots.py [0:0]


def render_img(train_config, img_samples, img_name=None, model_idxs=None):
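    """Render a full image by running the model(s) over all rays of img_samples in chunks.

    Assembles the per-model outputs, the training targets, and (if available)
    estimated and ground-truth depth maps into full-resolution buffers, saves
    them as PNGs to train_config.logDir, and prints the PSNR of the last
    model's output against its training target.
    """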
    targets = []
    imgs = []
    imgs_train_inference = []
    gt_depth = None
    gt_depth_world = None
    estimated_depth = None

    dim_h = train_config.dataset_info.h
    dim_w = train_config.dataset_info.w

    start_index = 0
    inference_chunk_size = train_config.config_file.inferenceChunkSize
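    # run inference in chunks of rays so that a full-resolution image never has
    # to be evaluated in a single forward pass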
    for batch in img_samples.batches(inference_chunk_size):
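        # evaluate each chunk through both the inference-time pipeline
        # (is_inference=True) and the training-time pipeline (is_inference=False)
        # so the two renderings can be saved side by side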
        img_parts, inf_dicts = train_config.inference(batch, gradient=False, is_inference=True)
        img_parts_train, inf_dicts_train = train_config.inference(batch, gradient=False, is_inference=False)

        # allocate the full-image output tensors once, on the first chunk, and
        # afterwards only write the per-chunk results into slices of them
        if len(imgs) == 0:
            for i in range(len(img_parts)):
                imgs.append(torch.zeros((dim_h * dim_w, img_parts[i].shape[-1]), device=train_config.device,
                                        dtype=torch.float32))
                imgs_train_inference.append(torch.zeros((dim_h * dim_w, img_parts_train[i].shape[-1]),
                                                        device=train_config.device, dtype=torch.float32))
                targets.append(torch.zeros((dim_h * dim_w, batch.get_train_target(i).shape[-1]),
                                           device=train_config.device, dtype=torch.float32))

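            # optional depth buffers, allocated only if the last feature set's
            # inference dict actually provides them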
            if estimated_depth is None and FeatureSetKeyConstants.nerf_estimated_depth in inf_dicts[-1]:
                estimated_depth = torch.zeros((dim_h * dim_w,
                                               inf_dicts[-1][FeatureSetKeyConstants.nerf_estimated_depth].shape[-1]),
                                              device=train_config.device, dtype=torch.float32)

            if gt_depth is None and FeatureSetKeyConstants.input_depth_groundtruth in inf_dicts_train[-1]:
                gt_depth = torch.zeros((dim_h * dim_w,
                                        inf_dicts_train[-1][FeatureSetKeyConstants.input_depth_groundtruth].shape[-1]),
                                       device=train_config.device, dtype=torch.float32)

            if gt_depth_world is None and FeatureSetKeyConstants.input_depth_groundtruth_world in inf_dicts_train[-1]:
                gt_depth_world = torch.zeros((dim_h * dim_w,
                                              inf_dicts_train[-1][FeatureSetKeyConstants.input_depth_groundtruth_world].shape[-1]),
                                             device=train_config.device, dtype=torch.float32)

        end_index = min(start_index + inference_chunk_size, dim_w * dim_h)

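        # copy this chunk's results into the corresponding rows of the
        # full-image buffers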
        for i in range(len(imgs)):
            imgs[i][start_index:end_index, :] = img_parts[i][:inference_chunk_size, :]
            imgs_train_inference[i][start_index:end_index, :] = img_parts_train[i][:inference_chunk_size, :]
            targets[i][start_index:end_index, :] = batch.get_train_target(i)[:inference_chunk_size, :]

        if estimated_depth is not None:
            estimated_depth[start_index:end_index, :] = \
                inf_dicts[-1][FeatureSetKeyConstants.nerf_estimated_depth][:inference_chunk_size, :]

        if gt_depth is not None:
            gt_depth[start_index:end_index, :] = \
                inf_dicts_train[-1][FeatureSetKeyConstants.input_depth_groundtruth][:inference_chunk_size, :]

        if gt_depth_world is not None:
            gt_depth_world[start_index:end_index, :] = \
                inf_dicts_train[-1][FeatureSetKeyConstants.input_depth_groundtruth_world][:inference_chunk_size, :]

        start_index = end_index

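    # write the assembled outputs, training-pipeline renderings and targets
    # to the log directory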
    if model_idxs is None:
        for i in range(len(imgs)):
            save_img(imgs[i], train_config.dataset_info, f"{train_config.logDir}{img_name}_{i}.png")
            save_img(imgs_train_inference[i], train_config.dataset_info, f"{train_config.logDir}{img_name}_{i}_train_input.png")
            save_img(targets[i], train_config.dataset_info, f"{train_config.logDir}{img_name}_{i}_train_targets.png")

    else:
        # only save the outputs of the requested models, using the model index
        # in the file names
        for i in model_idxs:
            save_img(imgs[i], train_config.dataset_info, f"{train_config.logDir}{img_name}_{i}.png")
            save_img(imgs_train_inference[i], train_config.dataset_info, f"{train_config.logDir}{img_name}_{i}_train_input.png")
            save_img(targets[i], train_config.dataset_info, f"{train_config.logDir}{img_name}_{i}_train_targets.png")

    if gt_depth is not None:
        save_img(gt_depth, train_config.dataset_info, f"{train_config.logDir}{img_name}_gt_depth.png")

    # when pretraining, we do not render estimated depth, as the result would not be correct
    if estimated_depth is not None:
        save_img(estimated_depth, train_config.dataset_info, f"{train_config.logDir}{img_name}_estimated_depth.png")

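    # report the PSNR of the last model's output against its training target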
    print(f'\nRender img PSNR {img_name}: {calculate_psnr(calculate_mse(targets[-1] - imgs[-1]))}\n')