def render_all_imgs()

in src/plots.py [0:0]


def render_all_imgs(train_config: TrainConfig, subfolder_name="", dataset_name="test"):
    """Render every image of a dataset split, save the results and print PSNRs.

    Each sample of ``dataset_name`` is passed through ``train_config.inference``
    in chunks of ``config_file.inferenceChunkSize`` pixels. The chunk outputs are
    stitched into full-resolution buffers. Files are written to
    ``os.path.join(train_config.logDir, subfolder_name, dataset_name)``:

    * ``_{net_idx}_{i}.png`` - one image per element of the image list returned
      by ``inference`` (presumably one per network - confirm).
    * ``_{i}_input_depth_gth.png`` / ``_{i}_estimated_depth.png`` - written when
      those keys are among the collected raw outputs of the last element.
    * ``{i:05d}_depth.npz`` and ``{i}_{suffix}_depth.png`` - the estimated depth
      converted with ``train_config.f_in[-1].depth_transform.to_world``. This
      only happens when an estimated depth was collected.
    * otherwise ``{i}_{j}_{key}_{suffix}.raw`` - every collected raw output,
      saved with ``torch.save``.

    Only raw outputs whose key is listed in ``config_file.outputNetworkRaw``
    are collected. The PSNR between the last image buffer and the training
    target is printed per image, followed by the average.

    Args:
        train_config: Training configuration providing the dataset, the
            inference call, the device and the log directory.
        subfolder_name: Sub-directory of ``train_config.logDir`` to write into.
        dataset_name: Name of the dataset split to render (default ``"test"``).

    Note:
        ``data_set.full_images`` is forced to ``True`` while rendering. It is
        restored to its previous value afterwards, even if rendering raises.
    """
    # Build the output directory once and derive every file path from it, so
    # files always land in the directory that was actually created.
    out_dir = os.path.join(train_config.logDir, subfolder_name, dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    config = train_config.config_file
    inference_chunk_size = config.inferenceChunkSize
    raw_output_keys = config.outputNetworkRaw

    dim_w = train_config.dataset_info.w
    dim_h = train_config.dataset_info.h
    num_pixels = dim_h * dim_w

    # The file-name suffix depends only on the config, so compute it once.
    raw_save_suffix = ""
    if "lin" not in config.depthTransform:
        raw_save_suffix += config.depthTransform[0:2]
    if config.rayMarchNormalization is not None and len(config.rayMarchNormalization) > 0:
        raw_save_suffix += nerf_get_normalization_function_abbr(config.rayMarchNormalization[-1])

    data_set, _ = train_config.get_data_set_and_loader(dataset_name)
    saved_full_images = data_set.full_images
    data_set.full_images = True

    psnrs = []

    try:
        # We use the dataset here, as using the data_loader would make it
        # necessary to handle all the different batch sizes.
        for i, sample_data in enumerate(tqdm(data_set, desc=f"rendering all images ({dataset_name})", position=0, leave=True)):
            img_samples = create_sample_wrapper(sample_data, train_config, True)

            imgs = []
            target = None
            inference_dict_full_list = []

            start_index = 0
            for batch in img_samples.batches(inference_chunk_size):
                img_part, inference_dict_part = train_config.inference(batch, gradient=False, is_inference=True)
                batch_target = batch.get_train_target(-1)

                if len(imgs) == 0:
                    # First chunk: the channel counts are now known, so allocate
                    # the flat [num_pixels, channels] buffers.
                    for j in range(len(img_part)):
                        imgs.append(torch.zeros((num_pixels, img_part[j].shape[-1]), device=train_config.device,
                                                dtype=torch.float32))
                        inference_dict_full = {}
                        if len(inference_dict_part) > 0:
                            for key, value in inference_dict_part[j].items():
                                if key in raw_output_keys:
                                    inference_dict_full[key] = torch.zeros((num_pixels, value.shape[-1]), device="cpu",
                                                                           dtype=torch.float32)
                        inference_dict_full_list.append(inference_dict_full)

                    target = torch.zeros((num_pixels, batch_target.shape[-1]),
                                         device=train_config.device, dtype=torch.float32)

                end_index = min(start_index + inference_chunk_size, num_pixels)
                # Slice by the span actually being filled (not by the chunk size)
                # so source and destination shapes agree even if the final
                # batch is padded.
                chunk_len = end_index - start_index

                for j in range(len(img_part)):
                    imgs[j][start_index:end_index] = img_part[j][:chunk_len]

                    # The allocation above tolerates an empty raw-output list,
                    # so the merge must as well.
                    if len(inference_dict_part) == 0:
                        continue
                    for key, value in inference_dict_part[j].items():
                        if key in raw_output_keys and inference_dict_full_list[j][key].ndim != 0:
                            inference_dict_full_list[j][key][start_index:end_index] = value[:chunk_len]

                target[start_index:end_index, :] = batch_target[:chunk_len, :]

                start_index = end_index

            # Reshape the flat per-pixel raw outputs to [h, w, ...]. The input
            # depth range is deliberately left flat.
            for full_dict in inference_dict_full_list:
                for key in full_dict:
                    buf = full_dict[key]
                    if buf.ndim != 0 and FeatureSetKeyConstants.input_depth_range not in key:
                        full_dict[key] = torch.reshape(buf, [dim_h, dim_w, *buf.shape[1:]])

            for net_idx, img in enumerate(imgs):
                save_img(img, train_config.dataset_info, os.path.join(out_dir, f"_{net_idx}_{i}.png"))

            last_outputs = inference_dict_full_list[-1]

            if FeatureSetKeyConstants.input_depth_groundtruth in last_outputs:
                save_img(last_outputs[FeatureSetKeyConstants.input_depth_groundtruth], train_config.dataset_info,
                         os.path.join(out_dir, f"_{i}_input_depth_gth.png"))

            if FeatureSetKeyConstants.nerf_estimated_depth in last_outputs:
                depth_map = last_outputs[FeatureSetKeyConstants.nerf_estimated_depth]
                save_img(depth_map, train_config.dataset_info, os.path.join(out_dir, f"_{i}_estimated_depth.png"))

                depth_range = last_outputs[FeatureSetKeyConstants.input_depth_range]
                # In case the depth range contains more than 2 elements due to the export
                input_depth_range = depth_range[:2]

                world_depth = train_config.f_in[-1].depth_transform.to_world(depth_map, input_depth_range[-1])
                np.savez(os.path.join(out_dir, f"{i:05d}_depth.npz"), world_depth)
                save_img(depth_map[..., None], train_config.dataset_info,
                         os.path.join(out_dir, f"{i}_{raw_save_suffix}_depth.png"))
            else:
                for j, full_dict in enumerate(inference_dict_full_list):
                    for key, buf in full_dict.items():
                        torch.save(buf, os.path.join(out_dir, f"{i}_{j}_{key}_{raw_save_suffix}.raw"))

            psnrs.append(calculate_psnr(calculate_mse(target - imgs[-1])))
    finally:
        # Restore the caller's dataset mode even if rendering failed part-way.
        data_set.full_images = saved_full_images

    print("\n")
    psnrs_np = []
    for i, psnr in enumerate(psnrs):
        print(f"Render all img psnr {i} {psnr}")
        psnrs_np.append(psnr.cpu().numpy())

    # Averaging an empty list would print nan and emit a RuntimeWarning.
    if psnrs_np:
        print(f"Average PSNR: {np.array(psnrs_np).mean()}")