def main()

in point_e/evals/scripts/evaluate_pfid.py [0:0]


def _batch_statistics(clf, path):
    """Return ``compute_statistics`` of the classifier features for one batch.

    Args:
        clf: the classifier whose ``features_and_preds`` is used to extract
            features (predictions are discarded).
        path: path to the batch file, read through ``NpzStreamer``.
    """
    features, _ = clf.features_and_preds(NpzStreamer(path))
    # Only the summary statistics are returned; the local feature array is
    # released when this helper returns, so it is not kept alive while the
    # next batch is processed (same intent as an explicit ``del``).
    return compute_statistics(features)


def main():
    """Compute and print the P-FID between two batches.

    Command-line arguments:
        batch_1, batch_2: paths to the two batches to compare; each is read
            through ``NpzStreamer`` and featurized by ``PointNetClassifier``.
        --cache_dir: optional directory forwarded to ``PointNetClassifier`` as
            ``cache_dir`` (default: None).

    Prints ``P-FID: <value>``, where the value is the Frechet distance between
    the feature statistics of the first and second batch.
    """
    parser = argparse.ArgumentParser(
        description="Compute P-FID between two batches."
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="cache directory passed to PointNetClassifier",
    )
    parser.add_argument("batch_1", type=str, help="path to the first batch")
    parser.add_argument("batch_2", type=str, help="path to the second batch")
    args = parser.parse_args()

    print("creating classifier...")
    clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)

    print("computing first batch activations")
    stats_1 = _batch_statistics(clf, args.batch_1)

    print("computing second batch activations")
    stats_2 = _batch_statistics(clf, args.batch_2)

    print(f"P-FID: {stats_1.frechet_distance(stats_2)}")