in src/evaluate.py [0:0]
def generate_data(train_config, flags, reference_video=None):
    """Run inference over a dataset and write diff/flip quality artifacts.

    Each dataset sample is rendered chunk-wise through the model(s) and
    compared against a reference. Depending on ``reference_video`` the
    results are written either as ``.mp4`` videos (predefined camera path)
    or as individual ``.png`` images under ``<outDir>/eval/``. Per-image
    quality metrics (MSE, plus PSNR/SSIM/FLIP when the corresponding flag
    is present) go to ``image_quality_<suffix>.txt``/``.csv``; with the
    "complexity" flag, per-image MAC counts go to ``complexity.txt``.

    Args:
        train_config: project training/eval configuration — supplies the
            models, device, dataset info, ``inference`` entry point and
            output directory. (Project type; structure assumed from usage
            here — TODO confirm against its definition.)
        flags: container supporting ``in`` tests, e.g. a set of strings
            such as "complexity", "output_images", "flip", "psnr", "ssim".
        reference_video: optional sequence of HxWx3 uint8 frames. When
            given, a predefined camera path is rendered and outputs are
            videos; otherwise ``train_config.test_dataset`` is used and
            outputs are images.

    Side effects: mutates ``train_config`` (camera options, FLOPs-counting
    wrappers on the models) and writes files under ``train_config.outDir``.
    """
    count_flops = False
    image_macs = []      # total MACs per generated image
    image_macs_pp = []   # MACs per pixel per generated image
    dim_h = train_config.dataset_info.h
    dim_w = train_config.dataset_info.w
    if reference_video is not None:
        # scale image data if config has scale != 1
        if train_config.dataset_info.scale != 1:
            video_scaled = np.zeros((len(reference_video), dim_h, dim_w, 3))
            for i, img in enumerate(reference_video):
                # BUGFIX: cv2.resize takes dsize as (width, height); the
                # previous (dim_h, dim_w) order produced a transposed
                # result and broke this assignment whenever dim_h != dim_w.
                video_scaled[i] = cv2.resize(img, (dim_w, dim_h), interpolation=cv2.INTER_AREA)
            reference_video = video_scaled
        # temporarily switch to the predefined camera path used for videos;
        # restored below via restore_camera_options()
        train_config.store_camera_options()
        train_config.config_file.camPath = "cam_path"
        train_config.config_file.camType = "PredefinedCamera"
        train_config.config_file.videoFrames = -1
        dataset = CameraViewCellDataset(train_config.config_file, train_config, train_config.dataset_info)
        desc = "Generating diff and flip videos"
        suffix = "video"
    else:
        os.makedirs(f"{train_config.outDir}/eval/", exist_ok=True)
        if "complexity" in flags:
            count_flops = True
            # wrap each model so MACs can be counted during inference
            for i in range(len(train_config.models)):
                train_config.models[i] = flops_counter.add_flops_counting_methods(train_config.models[i])
        dataset = train_config.test_dataset
        desc = "Generating diff and flip images"
        suffix = "images"
    q_cont = QualityContainer()
    # generate test image here to reduce time
    for i in tqdm(range(len(dataset)), desc=desc, position=0, leave=True):
        # acquire test and reference data
        img_data = create_sample_wrapper(dataset[i], train_config, True)
        if count_flops:
            for k in range(len(train_config.models)):
                train_config.models[k].start_flops_count(ost=None, verbose=False, ignore_list=[])
        # run inference chunk-wise to bound peak memory; `test` accumulates
        # the full flattened (H*W, 3) output image
        test = torch.zeros((dim_h * dim_w, 3), device=train_config.device, dtype=torch.float32)
        start_index = 0
        for batch in img_data.batches(train_config.config_file.inferenceChunkSize):
            img_part, _ = train_config.inference(batch, gradient=False, is_inference=True)
            end_index = min(start_index + train_config.config_file.inferenceChunkSize, dim_w * dim_h)
            test[start_index:end_index, :3] = img_part[-1][:train_config.config_file.inferenceChunkSize, :3]
            start_index = end_index
        if count_flops:
            total_macs = 0
            for k in range(len(train_config.models)):
                macs, params = train_config.models[k].compute_average_flops_cost()
                total_macs += macs
                train_config.models[k].stop_flops_count()
            # compute_average_flops_cost appears to report per-pixel MACs
            # (per-image total is obtained by multiplying with the pixel
            # count) — TODO confirm against flops_counter
            image_macs.append(total_macs * train_config.dataset_info.w * train_config.dataset_info.h)
            image_macs_pp.append(total_macs)
        # just so we can save the created images for debug purposes
        # NOTE(review): only this check guards against flags being None;
        # earlier `"complexity" in flags` would already have raised —
        # verify whether flags=None is a supported call.
        if flags is not None and "output_images" in flags:
            test_clone = torch.clone(test)
            test_clone = t2np(test_clone)
            test_clone = np.clip(test_clone.reshape(train_config.dataset_info.h,
                                                    train_config.dataset_info.w, -1), 0., 1.)[None]
            if q_cont.out_data is None:
                q_cont.out_data = test_clone
            else:
                q_cont.out_data = np.concatenate((q_cont.out_data, test_clone), axis=0)
            del test_clone
        if reference_video is None:
            reference = img_data.get_train_target(-1)
        else:
            # reference frames arrive as uint8; normalize to [0, 1] floats
            reference = (reference_video[i]).astype(np.float32)
            reference = reference / 255
            reference = torch.from_numpy(reference).to(train_config.device)
        # generate data in separate functions to get rid of intermediate values
        generate_diff_data(train_config, test, reference, q_cont, reference_video is None, flags)
        if "flip" in flags:
            generate_flip_data(train_config, test, reference, q_cont, reference_video is None)
    # save data to disk
    if reference_video is not None:
        train_config.restore_camera_options()
        print("Saving diff and flip videos")
        imageio.mimwrite(os.path.join(train_config.outDir, "_diff.mp4"), q_cont.diff_data, fps=30, quality=8)
        imageio.mimwrite(os.path.join(train_config.outDir, "_square_diff.mp4"), q_cont.square_diff_data, fps=30, quality=8)
        if "flip" in flags:
            imageio.mimwrite(os.path.join(train_config.outDir, "_flip.mp4"), q_cont.flip_data, fps=30, quality=8)
        if q_cont.out_data is not None:
            imageio.mimwrite(os.path.join(train_config.outDir, "_out.mp4"), q_cont.out_data, fps=30, quality=8)
    else:
        for i in tqdm(range(len(q_cont.diff_data)), desc="Saving diff and flip images", position=0, leave=True):
            save_img(q_cont.diff_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_diff_{q_cont.diff_data[i].mean()}.png")
            save_img(q_cont.square_diff_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_square_diff_{q_cont.square_diff_data[i].mean()}.png")
            if "flip" in flags:
                save_img(q_cont.flip_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_flip_{q_cont.flip_data[i].mean()}.png")
            if q_cont.out_data is not None:
                save_img(q_cont.out_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_out.png")
    if count_flops:
        with open(os.path.join(train_config.outDir, "complexity.txt"), "w") as f:
            cma_macs = 0
            cma_macs_pp = 0
            for idx in range(len(image_macs)):
                macs = image_macs[idx]
                macs_pp = image_macs_pp[idx]
                f.write(f"{idx} - {macs} - {macs_pp}\n")
                # to possibly avoid overflows, we calculate the cumulative moving averages only
                cma_macs = cma_macs + (macs - cma_macs) / (idx + 1)
                cma_macs_pp = cma_macs_pp + (macs_pp - cma_macs_pp) / (idx + 1)
            f.write(f"{cma_macs} : {cma_macs_pp}\n")
    # write quality info to .txt and .csv file
    # NOTE(review): lines below end with a bare "\r" (no "\n"), which most
    # tools render as a single overwritten line — possibly intended "\n";
    # kept as-is since downstream consumers may parse this format.
    with open(f"{train_config.outDir}/image_quality_{suffix}.txt", "w") as f:
        for idx, mse in enumerate(q_cont.mse):
            f.write(f"image={idx} mse={mse:.4f} psnr="
                    f"{q_cont.psnr[idx] if 'psnr' in flags else -1.:.4f} "
                    f"ssim="
                    f"{q_cont.ssim[idx] if 'ssim' in flags else -1.:.4f} "
                    f"flip_loss="
                    f"{q_cont.flip[idx] if 'flip' in flags else -1.:.4f}\r")
    with open(f"{train_config.outDir}/image_quality_{suffix}.csv", "w") as c:
        c.write("mse,psnr,ssim,flip\r")
        for idx, mse in enumerate(q_cont.mse):
            c.write(f"{mse},{q_cont.psnr[idx] if 'psnr' in flags else -1.},"
                    f"{q_cont.ssim[idx] if 'ssim' in flags else -1.},"
                    f"{q_cont.flip[idx] if 'flip' in flags else -1.}\r")