in quality_comparison/measure_reconstruction_completeness.py [0:0]
def _load_dataset_info(json_path, mode):
    """Load the per-image entries listed in ``json_path`` and tag each with ``mode``.

    Args:
        json_path: Path to a JSON file holding a list of dicts (one per image);
            entries are later indexed by "rgb_path" and "depth_path".
        mode: Value stored under the "mode" key of every entry.

    Returns:
        A new list of dicts, each a copy of an input entry with "mode" added.

    Raises:
        ValueError: If any entry already defines "mode". Because the entry's
            keys are unpacked after "mode", it would otherwise silently win.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        dataset_info = json.load(f)
    # Check every entry (not only the first) and tolerate empty datasets.
    if any("mode" in di for di in dataset_info):
        raise ValueError("dataset_info already contains the key 'mode'")
    return [{"mode": mode, **di} for di in dataset_info]


def _compute_completeness_stats(args):
    """Run ``is_image_defective`` over every dataset and aggregate per scene.

    Args:
        args: Namespace providing ``dataset_names``, ``json_paths`` (parallel
            to ``dataset_names``), ``mode`` and ``num_workers``.

    Returns:
        ``(scene_stats, image_stats)`` DataFrames.
        - ``scene_stats`` has one row per (dataset, scene), with columns
          "scene", "% defects" and "dataset".
        - ``image_stats`` has one row per image, with columns "scene",
          "has defect", "frac defects", "rgb_path" and "depth_path".

    Raises:
        ValueError: If ``dataset_names`` and ``json_paths`` differ in length.
    """
    if len(args.dataset_names) != len(args.json_paths):
        raise ValueError(
            f"Got {len(args.dataset_names)} dataset names but "
            f"{len(args.json_paths)} json paths"
        )
    scene_rows = []
    image_rows = []
    # The context manager shuts the worker processes down even if a dataset fails.
    with mp.Pool(args.num_workers) as pool:
        for dataset_name, json_path in zip(args.dataset_names, args.json_paths):
            print(f"=======> Evaluating {dataset_name}")
            dataset_info = _load_dataset_info(json_path, args.mode)
            # imap yields results in input order, so they line up with dataset_info.
            dataset_stats = list(
                tqdm.tqdm(
                    pool.imap(is_image_defective, dataset_info),
                    total=len(dataset_info),
                )
            )
            # Group per-image results by scene. Each result is unpacked as
            # (scene_name, has_defect, frac_defect).
            scene_stats = defaultdict(list)
            for info, (scene_name, has_defect, frac_defect) in zip(
                dataset_info, dataset_stats
            ):
                scene_stats[scene_name].append(
                    (has_defect, frac_defect, info["rgb_path"], info["depth_path"])
                )
            for scene_name, defects_info in scene_stats.items():
                defects = [has_defect for has_defect, *_ in defects_info]
                scene_rows.append(
                    {
                        "scene": scene_name,
                        # Mean of has_defect over the scene's images, as a percentage.
                        "% defects": np.mean(defects).item() * 100.0,
                        "dataset": dataset_name,
                    }
                )
                # NOTE(review): per-image rows carry no "dataset" column. Rows from
                # different datasets are only distinguishable by scene or path.
                image_rows.extend(
                    {
                        "scene": scene_name,
                        "has defect": has_defect,
                        "frac defects": frac_defect,
                        "rgb_path": rgb_path,
                        "depth_path": depth_path,
                    }
                    for has_defect, frac_defect, rgb_path, depth_path in defects_info
                )
    # Explicit columns keep the CSV header (and a re-readable cache) even when
    # no rows were produced.
    scene_df = pd.DataFrame(scene_rows, columns=["scene", "% defects", "dataset"])
    image_df = pd.DataFrame(
        image_rows,
        columns=["scene", "has defect", "frac defects", "rgb_path", "depth_path"],
    )
    return scene_df, image_df


def _plot_defect_histogram(stats, dataset_names, figsize, save_path):
    """Save a step histogram of per-scene "% defects", one curve per dataset."""
    dataset_names = list(dataset_names)
    plt.figure(figsize=figsize)
    sns.histplot(
        stats,
        x="% defects",
        element="step",
        hue="dataset",
        # Pin the hue order so each COLOR_MAPPING color is tied to its dataset
        # name, not to the order rows happen to appear in (possibly cached) stats.
        hue_order=dataset_names,
        fill=False,
        bins=25,
        palette=[COLOR_MAPPING[d] for d in dataset_names],
    )
    plt.yscale("log")
    plt.xlabel("% defects", fontdict=axes_font)
    plt.ylabel("# scenes", fontdict=axes_font)
    ax = plt.gca()
    plt.xlim(0, 100)
    plt.ylim(1, 1000)
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(LogLocator())
    # Style tick labels once limits and locators are final.
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontproperties(ticks_font)
    plt.tight_layout()
    plt.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()


def _print_dataset_means(stats):
    """Print the mean per-scene "% defects" for each dataset (sorted by name)."""
    dataset_means = stats.groupby("dataset")["% defects"].mean()
    for dataset_name, mean_pct in dataset_means.items():
        print(f"{dataset_name}: {mean_pct:.4f}")


def measure_reconstruction_completeness(args):
    """Measure, plot and summarize the share of defective images per scene.

    Caching behavior:
    - If ``<save_dir>/dataset_stats.csv`` does not exist, every dataset is
      evaluated with ``is_image_defective``. Per-scene results go to
      ``dataset_stats.csv`` and per-image results to
      ``dataset_image_stats.csv``.
    - Otherwise the cached per-scene CSV is reused as-is. The cache is keyed
      only on the file's existence, so delete it to force recomputation after
      changing ``args``.

    In both cases a histogram is written to ``<save_dir>/histplot.png`` and
    the mean "% defects" per dataset is printed.

    Args:
        args: Namespace providing ``save_dir``, ``dataset_names``, ``figsize``,
            and, when the cache is missing, ``json_paths``, ``mode`` and
            ``num_workers``.
    """
    os.makedirs(args.save_dir, exist_ok=True)
    stats_path = osp.join(args.save_dir, "dataset_stats.csv")
    per_image_stats_path = osp.join(args.save_dir, "dataset_image_stats.csv")
    if osp.isfile(stats_path):
        # Only the per-scene table is used below, so the per-image CSV is not read.
        stats = pd.read_csv(stats_path, index_col=False)
    else:
        stats, per_image_stats = _compute_completeness_stats(args)
        stats.to_csv(stats_path, index=False)
        per_image_stats.to_csv(per_image_stats_path, index=False)
    _plot_defect_histogram(
        stats,
        args.dataset_names,
        args.figsize,
        osp.join(args.save_dir, "histplot.png"),
    )
    _print_dataset_means(stats)