in extra_scripts/generate_jigsaw_permutations.py [0:0]
def main():
    """Select a set of mutually distant patch permutations and save it.

    Command-line entry point. Options:

    * ``--N``: number of permutations to select (default 1000).
    * ``--M``: number of patches to permute (default 9).
    * ``--method``: ``max_avg`` picks the candidate with the largest average
      hamming distance to the permutations selected so far, ``max_min`` the
      candidate with the largest minimum distance.
    * ``--min_distance``: for ``max_min`` only, a message is logged whenever
      the permutation just chosen is closer than this (normalized hamming
      distance) to the already selected set.
    * ``--output_dir``: directory receiving the result; created if missing.

    The selection is the greedy procedure of Algorithm 1 (page 12) in
    https://arxiv.org/pdf/1603.09246.pdf: the first permutation is sampled
    uniformly at random, every further one maximizes the chosen criterion
    over the permutations not selected yet.

    The selected permutations, an integer array of shape ``(N, M)``, are
    written with ``np.save`` to
    ``<output_dir>/hamming_perms_<N>_patches_<M>_<method>.npy``.

    The process exits with status 1 (after printing the help) when called
    without arguments, and through ``parser.error`` when ``--output_dir`` is
    missing or ``--N`` / ``--M`` are out of range.
    """
    parser = argparse.ArgumentParser(description='Permutations for patches')
    parser.add_argument('--N', type=int, default=1000,
                        help='Number of permuations')
    parser.add_argument('--M', type=int, default=9,
                        help='Number of patches to permute')
    parser.add_argument('--method', type=str, default='max_avg',
                        choices=['max_avg', 'max_min'],
                        help='hamming distance : max, avg')
    parser.add_argument('--min_distance', type=float, default=2.0 / 9.0,
                        help='min distance of permutations in final set')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='output dir')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    # Reject unusable arguments before the expensive selection starts rather
    # than crashing at the very end (os.path.join(None, ...)) or mid-loop.
    if args.output_dir is None:
        parser.error('--output_dir is required')
    if args.M < 1:
        parser.error('--M must be at least 1')
    if args.N < 1:
        parser.error('--N must be at least 1')

    # now generate data permutation for num_perms, num_patches and save them.
    # The algorithm followed is same as in https://arxiv.org/pdf/1603.09246.pdf
    # Algorithm 1 on page 12.
    logger.info("Generating all perms....")
    num_perms, num_patches = args.N, args.M
    all_perms = np.array(list(itertools.permutations(range(num_patches))))
    total_perms = all_perms.shape[0]
    if num_perms > total_perms:
        parser.error('--N ({}) exceeds the {} permutations of {} patches'.format(
            num_perms, total_perms, num_patches))
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)
    logger.info("Selecting perms from set of {} perms".format(total_perms))

    # Preallocate the result instead of growing it with np.concatenate.
    selected_perms = np.empty((num_perms, num_patches), dtype=all_perms.dtype)

    # `score[c]` aggregates the hamming distances between remaining candidate
    # `c` and every permutation selected so far: their sum for max_avg (same
    # argmax as their mean), their minimum for max_min. It is updated with the
    # distances to the newest selection only, so the full selected-vs-remaining
    # distance matrix is never recomputed.
    # argparse `choices` restricts --method to 'max_avg' and 'max_min'.
    combine = np.add if args.method == 'max_avg' else np.minimum
    score = None

    j = np.random.randint(total_perms)  # uniformly sample first perm
    for idx in range(num_perms):
        selected_perms[idx] = all_perms[j]
        all_perms = np.delete(all_perms, j, axis=0)

        # After the last selection there is nothing left to choose (and the
        # candidate pool may be empty when N equals the number of perms).
        if idx + 1 < num_perms:
            if score is not None:
                # keep `score` aligned with the rows of `all_perms`
                score = np.delete(score, j)
            # distances between the perm just selected and each candidate
            new_dist = cdist(selected_perms[idx:idx + 1], all_perms,
                             metric='hamming')[0]
            score = new_dist if score is None else combine(score, new_dist)
            j = score.argmax()
            # For max_min, score[j] is the distance between the next pick and
            # its nearest already selected permutation, i.e. the quantity
            # that bounds the minimum distance inside the final set.
            if args.method == 'max_min' and score[j] < args.min_distance:
                logger.info("min distance {} < threshold {}".format(
                    score[j], args.min_distance))

        if (idx + 1) % 100 == 0:
            logger.info('selected_perms: {} -> {}'.format(
                (idx + 1), selected_perms[:idx + 1].shape))

    if num_perms > 1:
        dists_sel = cdist(selected_perms, selected_perms, metric='hamming')
        # drop the zero self-distances on the diagonal
        non_diag_elements = dists_sel[~np.eye(num_perms, dtype=bool)]
        # cdist returns normalized distances; scale back to mismatch counts
        mean_dist = non_diag_elements.mean() * num_patches
        min_dist = non_diag_elements.min() * num_patches
        logger.info('Permutation stats: avg dist {}; min dist {}'.format(
            mean_dist, min_dist))
    else:
        # a single permutation has no pairwise distances to summarize
        logger.info('Permutation stats: skipped (only one permutation)')

    perm_file = os.path.join(
        args.output_dir,
        'hamming_perms_{}_patches_{}_{}.npy'.format(args.N, args.M, args.method)
    )
    logger.info('Writing permutations to: {}'.format(perm_file))
    logger.info('permutations shape: {}'.format(selected_perms.shape))
    np.save(perm_file, selected_perms)
    logger.info('Done!')