in src/plots.py [0:0]
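# NOTE: a sketch of the module-level imports this function relies on elsewhere in
# plots.py (assumed, not shown in this excerpt): os, numpy as np, torch, tqdm,
# plus project helpers such as TrainConfig, create_sample_wrapper, save_img,
# calculate_psnr, calculate_mse, FeatureSetKeyConstants and
# nerf_get_normalization_function_abbr.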
def render_all_imgs(train_config: TrainConfig, subfolder_name="", dataset_name="test"):
    os.makedirs(os.path.join(train_config.logDir, subfolder_name, dataset_name), exist_ok=True)
    inference_chunk_size = train_config.config_file.inferenceChunkSize
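    # Fetch the dataset directly and temporarily switch it to full-image mode so each
    # sample yields every ray of one image; the original flag is restored at the end.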
    data_set, _ = train_config.get_data_set_and_loader(dataset_name)

    saved_full_images = data_set.full_images
    data_set.full_images = True

    psnrs = []
    dim_w = train_config.dataset_info.w
    dim_h = train_config.dataset_info.h

    # We use the dataset here, as using the data_loader would make it necessary
    # to handle all the different batch sizes.
    for i, sample_data in enumerate(tqdm(data_set, desc=f"rendering all images ({dataset_name})", position=0, leave=True)):
        img_samples = create_sample_wrapper(sample_data, train_config, True)

        imgs = []
        target = None
        inference_dict_full_list = []
        start_index = 0
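        # Run inference in chunks of inference_chunk_size rays to keep memory bounded,
        # scattering each chunk's outputs into full-resolution buffers below.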
        for batch in img_samples.batches(inference_chunk_size):
            img_part, inference_dict_part = train_config.inference(batch, gradient=False, is_inference=True)
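            # On the first chunk, allocate the full-resolution buffers: one flat
            # [h * w, c] image per network, a per-network dict for the requested raw
            # outputs (kept on the CPU), and the training target.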
            if len(imgs) == 0:
                for j in range(len(img_part)):
                    imgs.append(torch.zeros((dim_h * dim_w, img_part[j].shape[-1]), device=train_config.device,
                                            dtype=torch.float32))

                    inference_dict_full = {}
                    if len(inference_dict_part) > 0:
                        for key, value in inference_dict_part[j].items():
                            if key in train_config.config_file.outputNetworkRaw:
                                inference_dict_full[key] = torch.zeros((dim_h * dim_w, value.shape[-1]), device="cpu",
                                                                       dtype=torch.float32)
                    inference_dict_full_list.append(inference_dict_full)

                target = torch.zeros((dim_h * dim_w, batch.get_train_target(-1).shape[-1]),
                                     device=train_config.device, dtype=torch.float32)
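            # Scatter this chunk's predictions, raw outputs and targets into the flat
            # buffers at [start_index:end_index].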
            end_index = min(start_index + inference_chunk_size, dim_w * dim_h)

            for j in range(len(img_part)):
                imgs[j][start_index:end_index] = img_part[j][:inference_chunk_size]

                for key, value in inference_dict_part[j].items():
                    if key in train_config.config_file.outputNetworkRaw:
                        if inference_dict_full_list[j][key].ndim != 0:
                            inference_dict_full_list[j][key][start_index:end_index] = value[:inference_chunk_size]

            target[start_index:end_index, :] = batch.get_train_target(-1)[:inference_chunk_size, :]
            start_index = end_index
        # Reshape all stored raw outputs from flat [h * w, ...] to [h, w, ...];
        # the depth range entries keep their original shape.
        for j in range(len(inference_dict_full_list)):
            for key in inference_dict_full_list[j]:
                if inference_dict_full_list[j][key].ndim != 0 and FeatureSetKeyConstants.input_depth_range not in key:
                    inference_dict_full_list[j][key] = torch.reshape(inference_dict_full_list[j][key],
                                                                     [dim_h, dim_w, *inference_dict_full_list[j][key].shape[1:]])
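        # Save the predicted image of every network, plus ground-truth and estimated
        # depth maps when the last network provides them.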
        for net_idx, img in enumerate(imgs):
            save_img(img, train_config.dataset_info, f"{train_config.logDir}{subfolder_name}{dataset_name}/_{net_idx}_{i}.png")

        if FeatureSetKeyConstants.input_depth_groundtruth in inference_dict_full_list[-1]:
            save_img(inference_dict_full_list[-1][FeatureSetKeyConstants.input_depth_groundtruth], train_config.dataset_info,
                     f"{train_config.logDir}{subfolder_name}{dataset_name}/_{i}_input_depth_gth.png")

        if FeatureSetKeyConstants.nerf_estimated_depth in inference_dict_full_list[-1]:
            save_img(inference_dict_full_list[-1][FeatureSetKeyConstants.nerf_estimated_depth], train_config.dataset_info,
                     f"{train_config.logDir}{subfolder_name}{dataset_name}/_{i}_estimated_depth.png")
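        # Encode the depth transform and ray-march normalization in the file name
        # suffix of the raw/depth exports.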
        raw_save_suffix = ""
        if "lin" not in train_config.config_file.depthTransform:
            raw_save_suffix += train_config.config_file.depthTransform[0:2]
        if train_config.config_file.rayMarchNormalization is not None and len(train_config.config_file.rayMarchNormalization) > 0:
            raw_save_suffix += nerf_get_normalization_function_abbr(train_config.config_file.rayMarchNormalization[-1])
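        # When an estimated depth map is available, convert it back to world-space
        # depth via the per-view depth range and export it for later evaluation.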
        if FeatureSetKeyConstants.nerf_estimated_depth in inference_dict_full_list[-1]:
            # Load depth range and depth map
            depth_range = inference_dict_full_list[-1][FeatureSetKeyConstants.input_depth_range]
            depth_map = inference_dict_full_list[-1][FeatureSetKeyConstants.nerf_estimated_depth]

            # In case the depth range contains more than 2 elements due to the export
            input_depth_range = depth_range[:2]

            world_depth = train_config.f_in[-1].depth_transform.to_world(depth_map, input_depth_range[-1])
            np.savez(f"{train_config.logDir}{subfolder_name}{dataset_name}/{i:05d}_depth.npz", world_depth)
            save_img(depth_map[..., None], train_config.dataset_info,
                     f"{train_config.logDir}{subfolder_name}{dataset_name}/{i}_{raw_save_suffix}_depth.png")
        else:
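            # No estimated depth available: dump every collected raw output tensor instead.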
            for j in range(len(inference_dict_full_list)):
                for key in inference_dict_full_list[j]:
                    torch.save(inference_dict_full_list[j][key],
                               f"{train_config.logDir}{subfolder_name}{dataset_name}/{i}_{j}_{key}_{raw_save_suffix}.raw")
        psnrs.append(calculate_psnr(calculate_mse(target - imgs[-1])))

    print("\n")
    psnrs_np = []
    for i in range(len(psnrs)):
        print(f"Render all img psnr {i} {psnrs[i]}")
        psnrs_np.append(psnrs[i].cpu().numpy())

    print(f"Average PSNR: {np.array(psnrs_np).mean()}")

    data_set.full_images = saved_full_images
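
# Usage sketch (assumes a TrainConfig has been constructed/loaded elsewhere; the
# argument values below are illustrative only):
#   render_all_imgs(train_config, subfolder_name="images/", dataset_name="test")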