def measure_reconstruction_completeness()

in quality_comparison/measure_reconstruction_completeness.py
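
Computes, per scene and per dataset, the fraction of images flagged as defective, caches the results as CSV tables under args.save_dir, and plots a histogram of per-scene defect percentages. The body relies on module-level imports and helpers that are not part of this excerpt; a plausible set, inferred from the names it uses, is sketched below (COLOR_MAPPING, ticks_font, axes_font, and is_image_defective are defined elsewhere in the repository):

import json
import multiprocessing as mp
import os
import os.path as osp
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
from matplotlib.ticker import LogLocator, MultipleLocator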


def measure_reconstruction_completeness(args):
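    # Worker pool used to run the per-image defect check (is_image_defective) in parallel.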
    pool = mp.Pool(args.num_workers)

    stats = []
    per_image_stats = []
    os.makedirs(args.save_dir, exist_ok=True)
    stats_path = f"{args.save_dir}/dataset_stats.csv"
    per_image_stats_path = f"{args.save_dir}/dataset_image_stats.csv"
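    # Recompute the statistics only when no cached CSV exists; otherwise the cached tables are loaded below.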
    if not osp.isfile(stats_path):
        for dataset_name, json_path in zip(args.dataset_names, args.json_paths):
            print(f"=======> Evaluating {dataset_name}")
            with open(json_path, "r") as f:
                dataset_info = json.load(f)
            # Update dataset_info to include mode
            assert (
                "mode" not in dataset_info[0]
            ), "dataset_info already contains the key 'mode'"
            dataset_info = [{"mode": args.mode, **di} for di in dataset_info]
            # Compute stats over the complete dataset
            dataset_stats = list(
                tqdm.tqdm(
                    pool.imap(is_image_defective, dataset_info), total=len(dataset_info)
                )
            )
            # Compute scene-specific stats
            scene_stats = defaultdict(list)
            for info, (scene_name, has_defect, frac_defect) in zip(
                dataset_info, dataset_stats
            ):
                scene_stats[scene_name].append(
                    (has_defect, frac_defect, info["rgb_path"], info["depth_path"])
                )
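            # Summarize each scene: percentage of defective images for the dataset-level table, plus one per-image record.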
            for scene_name, defects_info in scene_stats.items():
                defects = [di[0] for di in defects_info]
                stats.append(
                    {
                        "scene": scene_name,
                        "% defects": np.mean(defects).item() * 100.0,
                        "dataset": dataset_name,
                    }
                )
                per_image_stats += [
                    {
                        "scene": scene_name,
                        "has defect": di[0],
                        "frac defects": di[1],
                        "rgb_path": di[2],
                        "depth_path": di[3],
                    }
                    for di in defects_info
                ]
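        # Cache both tables so subsequent runs can skip the recomputation above.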
        stats = pd.DataFrame(stats)
        stats.to_csv(stats_path, index=False)
        per_image_stats = pd.DataFrame(per_image_stats)
        per_image_stats.to_csv(per_image_stats_path, index=False)
    else:
        stats = pd.read_csv(stats_path, index_col=False)
        per_image_stats = pd.read_csv(per_image_stats_path, index_col=False)

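    # Histogram of per-scene defect percentages: one step curve per dataset, scene counts on a log scale.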
    plt.figure(figsize=args.figsize)
    sns.histplot(
        stats,
        x="% defects",
        element="step",
        hue="dataset",
        fill=False,
        bins=25,
        palette=[COLOR_MAPPING[d] for d in args.dataset_names],
    )
    plt.yscale("log")
    for label in plt.xticks()[1] + plt.yticks()[1]:
        label.set_fontproperties(ticks_font)
    plt.xlabel("% defects", fontdict=axes_font)
    plt.ylabel("# scenes", fontdict=axes_font)
    ax = plt.gca()
    plt.xlim(0, 100)
    plt.ylim(1, 1000)
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(LogLocator())
    plt.tight_layout()

    plt.savefig(f"{args.save_dir}/histplot.png")

    # Print the average % defects per dataset
    for dataset_name, v_mean in stats.groupby("dataset")["% defects"].mean().items():
        print(f"{dataset_name}: {v_mean:.4f}")
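
For reference, a minimal invocation sketch. The field names mirror the attributes read from args above; the concrete values (worker count, mode string, paths, dataset names, figure size) are placeholders rather than values taken from the repository, and each JSON file is expected to hold a list of per-image records with at least rgb_path and depth_path fields:

from argparse import Namespace

args = Namespace(
    num_workers=8,                              # parallel defect-check workers
    mode="rendered",                            # placeholder; forwarded to is_image_defective
    save_dir="results/completeness",            # output directory for the CSVs and histogram
    dataset_names=["dataset_a", "dataset_b"],   # must be keys of COLOR_MAPPING
    json_paths=["dataset_a.json", "dataset_b.json"],
    figsize=(6, 4),
)
measure_reconstruction_completeness(args)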