def main()

in extra_scripts/generate_jigsaw_permutations.py [0:0]


def _select_permutations(all_perms, num_perms, method, min_distance):
    """Greedily pick ``num_perms`` rows of ``all_perms`` that are far apart in Hamming distance.

    Follows Algorithm 1 (page 12) of https://arxiv.org/pdf/1603.09246.pdf:
    1. Draw the first permutation uniformly at random.
    2. Pick each next one as the remaining candidate that maximises either the
       average ("max_avg") or the minimum ("max_min") Hamming distance to the
       permutations already selected.

    Distances are aggregated incrementally. Each step compares the candidates
    only with the permutation just added, instead of recomputing the whole
    selected-vs-remaining distance matrix. The distances are integer counts of
    differing positions, so ties in the argmax are resolved exactly (first
    index wins).

    Args:
        all_perms: 2-D array, one candidate permutation per row.
        num_perms: number of rows to select; the caller ensures
            ``1 <= num_perms <= all_perms.shape[0]``.
        method: "max_avg" or "max_min".
        min_distance: threshold on the Hamming distance as a fraction of the
            permutation length. Only checked for "max_min", where a warning
            is logged the first time the chosen candidate falls below it.

    Returns:
        Array of shape ``(num_perms, all_perms.shape[1])`` holding the
        selected rows in selection order.
    """
    total_perms, num_patches = all_perms.shape
    available = np.ones(total_perms, dtype=bool)
    if method == "max_avg":
        # Sum of distances to the selected set; argmax(sum) == argmax(mean).
        score = np.zeros(total_perms, dtype=np.int64)
    else:
        # Min distance to the selected set; num_patches is the largest possible
        # Hamming distance, so it acts as "no selected perm seen yet".
        score = np.full(total_perms, num_patches, dtype=np.int64)

    selected = []
    warned = False
    j = np.random.randint(total_perms)  # uniformly sample first perm
    for idx in range(num_perms):
        selected.append(j)
        available[j] = False
        if (idx + 1) % 100 == 0:
            logger.info(f"selected_perms: {(idx + 1)} -> {(idx + 1, num_patches)}")
        if idx + 1 == num_perms:
            # Done: do not score candidates again (there may be none left).
            break
        # Number of positions where each candidate differs from the new perm.
        dist = (all_perms != all_perms[j]).sum(axis=1)
        if method == "max_avg":
            score += dist
        else:
            np.minimum(score, dist, out=score)
        # Already-selected rows get -1 so they can never win; real scores are >= 0.
        j = int(np.where(available, score, -1).argmax())
        if method == "max_min" and not warned:
            chosen_dist = score[j] / num_patches
            if chosen_dist < min_distance:
                # The max-min distance can only shrink as the set grows, so
                # warn once instead of on every following selection.
                logger.warning(
                    f"min distance {chosen_dist} < threshold {min_distance} "
                    f"at selection {idx + 2}"
                )
                warned = True
    return all_perms[selected]


def _log_permutation_stats(selected_perms):
    """Log the average and minimum pairwise Hamming distance (in positions) of the set."""
    count, num_patches = selected_perms.shape
    if count < 2:
        # No pairs exist; min() of an empty array would raise.
        logger.info("Permutation stats: fewer than 2 perms, no pairwise distances")
        return
    dists_sel = cdist(selected_perms, selected_perms, metric="hamming")
    # cdist "hamming" is a fraction of positions; scale back to a count.
    non_diag_elements = dists_sel[~np.eye(count, dtype=bool)]
    mean_dist = non_diag_elements.mean() * num_patches
    min_dist = non_diag_elements.min() * num_patches
    logger.info(f"Permutation stats: avg dist {mean_dist}; min dist {min_dist}")


def main(argv=None):
    """Generate a set of mutually distant patch permutations and save it as ``.npy``.

    Steps:
    1. Parse command-line options.
    2. Enumerate all ``M!`` permutations of ``range(M)``.
    3. Greedily select ``N`` of them by Hamming distance.
    4. Log pairwise-distance statistics.
    5. Write the result to
       ``<output_dir>/hamming_perms_<N>_patches_<M>_<method>.npy``.

    Args:
        argv: optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: raised by ``parser.error`` when ``--M < 1`` or ``--N`` is
            not in ``[1, M!]``.
    """
    import math
    import os

    parser = argparse.ArgumentParser(description="Permutations for patches")
    parser.add_argument("--N", type=int, default=1000, help="Number of permutations")
    parser.add_argument("--M", type=int, default=9, help="Number of patches to permute")
    parser.add_argument(
        "--method",
        type=str,
        default="max_avg",
        choices=["max_avg", "max_min"],
        help="hamming distance : max_avg, max_min",
    )
    parser.add_argument(
        "--min_distance",
        type=float,
        default=2.0 / 9.0,
        help="min Hamming distance (as a fraction of M) of permutations in "
        "final set; only checked for max_min",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory where permutations should be saved",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Optional seed for numpy's global RNG, for reproducible output",
    )
    args = parser.parse_args(argv)

    num_perms, num_patches = args.N, args.M
    if num_patches < 1:
        parser.error(f"--M must be >= 1, got {num_patches}")
    # Validate before materialising all M! permutations in memory.
    total_perms = math.factorial(num_patches)
    if not 1 <= num_perms <= total_perms:
        parser.error(
            f"--N must be between 1 and M! = {total_perms}, got {num_perms}"
        )
    if args.seed is not None:
        np.random.seed(args.seed)

    # now generate data permutation for num_perms, num_patches and save them.
    # The algorithm followed is same as in https://arxiv.org/pdf/1603.09246.pdf
    # Algorithm 1 on page 12.
    logger.info(
        f"Generating all perms: M (#patches): {args.M}, "
        f"N (#perms): {args.N}, method: {args.method}"
    )
    all_perms = np.array(list(itertools.permutations(range(num_patches))))

    logger.info(f"Selecting perms from set of {total_perms} perms")
    selected_perms = _select_permutations(
        all_perms, num_perms, args.method, args.min_distance
    )
    _log_permutation_stats(selected_perms)

    os.makedirs(args.output_dir, exist_ok=True)
    perm_file = os.path.join(
        args.output_dir,
        f"hamming_perms_{args.N}_patches_{args.M}_{args.method}.npy",
    )
    logger.info(f"Writing permutations to: {perm_file}")
    logger.info(f"permutations shape: {selected_perms.shape}")
    np.save(perm_file, selected_perms)
    logger.info("Done!")