def main()

in pytorch-transformers/pseudoalignment/pseudo_decomp_variable.py


import argparse
import os

import faiss
import numpy as np
from sklearn.preprocessing import normalize
from tqdm import tqdm

# Helpers load_nlp, load_compositional_raw_qs, load_single_hop_qs, and
# vectorize are assumed to be defined elsewhere in this module.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default='dev', choices=['train', 'dev'], type=str,
                        help="Find NNs for which HotpotQA split?")
    parser.add_argument("--min_q_len", default=4, type=int,
                        help="Minimum number of spacy tokens allowed in a SQuAD question")
    parser.add_argument("--beam_size", default=100, type=int,
                        help="Top-K most similar questions to comp Q to consider")
    parser.add_argument("--max_q_len", default=20, type=int,
                        help="Max number of spacy tokens allowed in a SQuAD question")
    parser.add_argument("--n_mined", default=0, type=int,
                        help="Number of mined questions to use")
    parser.add_argument("--target_question_length", type=int, default=2)
    parser.add_argument("--data_folder", type=str, help="Path to save results to")
    args = parser.parse_args()
    args.lower_case = False
    args.use_mined_questions = False

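    # Load the NLP pipeline (spaCy, per the "spacy tokens" CLI help) used by
    # vectorize() below.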
    nlp = load_nlp()

    raw_hard_qs = load_compositional_raw_qs(args.split)

    # Load single-hop questions and their vector representations, once as-is
    # and once with question marks stripped
    raw_qs, qs = load_single_hop_qs(nlp, args.n_mined, args.min_q_len,
                                    args.max_q_len, strip_qmarks=False,
                                    lower_case=args.lower_case, use_mined_questions=args.use_mined_questions)
    raw_qs_stripped, qs_stripped = load_single_hop_qs(nlp, args.n_mined, args.min_q_len, args.max_q_len,
                                                      strip_qmarks=True, lower_case=args.lower_case)

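    # L2-normalize the candidate vectors so that inner-product (IP) search in
    # the FAISS index is equivalent to cosine-similarity search; 300 is the
    # word-vector dimensionality assumed throughout this script.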
    normed_qs = normalize(qs, axis=1, norm='l2')
    candidate_index = faiss.IndexFlatIP(300)
    candidate_index.add(np.ascontiguousarray(normed_qs, dtype=np.float32))  # faiss expects contiguous float32

    solutions = []
    solution_scores = []

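    # Euclidean distance from `query` (shape (1, d)) to every row of
    # `search_space`, via ||q - x||^2 = ||q||^2 + ||x||^2 - 2 x.q, so the
    # per-row squared norms can be precomputed once and reused across calls.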
    def nn_search(query, search_space, cached_square_norms=None):
        if cached_square_norms is None:
            cached_square_norms = np.sum(search_space ** 2, axis=1)
        dists_sq = np.sum(query ** 2) + cached_square_norms - 2 * np.dot(search_space, query[0])
        # Clip tiny negatives from floating-point error before taking the sqrt
        return np.sqrt(np.maximum(dists_sq, 0.0))

    all_cached_square_norms = np.sum(qs_stripped**2, axis=1)
    k = args.beam_size
    faiss_batch_size = 500

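    # Process the multi-hop (hard) questions in batches so a single FAISS
    # search serves many queries at once; each question is then refined
    # individually.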
    for bn in tqdm(range(0, len(raw_hard_qs), faiss_batch_size)):
        raw_hard_qs_batch = raw_hard_qs[bn: bn + faiss_batch_size]
        fbs = len(raw_hard_qs_batch)

        hard_q_stripped_batch = np.zeros((fbs, 300), dtype=np.float32)
        normed_hard_q_batch = np.zeros((fbs, 300), dtype=np.float32)

        for i, raw_hard_q in enumerate(raw_hard_qs_batch):
            _, _, hard_q, _ = vectorize(nlp, raw_hard_q, strip_qmarks=False, lowercase=args.lower_case)
            _, _, hard_q_stripped, _ = vectorize(nlp, raw_hard_q, strip_qmarks=True, lowercase=args.lower_case)
            normed_hard_q = normalize(hard_q[np.newaxis, :], axis=1, norm='l2')
            hard_q_stripped_batch[i] = hard_q_stripped
            normed_hard_q_batch[i] = normed_hard_q

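        # One batched FAISS search: for each hard question, retrieve the k most
        # cosine-similar single-hop candidates (Dts: similarities, Its: indices).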
        Dts, Its = candidate_index.search(normed_hard_q_batch, k)

        for i, raw_hard_q in enumerate(raw_hard_qs_batch):
            hard_q_stripped = hard_q_stripped_batch[i: i+1]
            Dt, It = Dts[i: i+1], Its[i: i+1]

            qs_stripped_top_k = qs_stripped[It[0]]
            cached_square_norms = all_cached_square_norms[It[0]]

            zero_hop_scores = nn_search(hard_q_stripped, qs_stripped_top_k, cached_square_norms=cached_square_norms)

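            # Beam of k partial decompositions: row j is seeded with the j-th
            # retrieved candidate; column t stores the index of the question
            # added at step t, with a parallel array for scores.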
            solution_list = np.zeros((k, args.target_question_length + 1), dtype=np.int32)  # (k, target_len + 1)
            solution_list[:, 0] = It[0]
            solution_scores_list = np.zeros((k, args.target_question_length + 1))
            solution_scores_list[:, 0] = zero_hop_scores
            old_solution_vecs = qs_stripped_top_k.copy()

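            # Greedily extend each beam row: at every step, add the candidate
            # whose vector brings the row's running sum closest (in Euclidean
            # distance) to the hard question's vector.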
            n_solution_iters = 1
            while n_solution_iters < (args.target_question_length + 1):
                for solution_index in range(k):
                    old_solution_vec = old_solution_vecs[solution_index]
                    query = hard_q_stripped - old_solution_vec
                    total_scores = nn_search(query, qs_stripped_top_k, cached_square_norms=cached_square_norms)
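                    # Mask this beam row's seed candidate so it cannot be
                    # selected again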
                    total_scores[solution_index] = 1e6

                    best_temp = np.argmin(total_scores)
                    best = It[0, best_temp]
                    best_score = total_scores[best_temp]

                    solution_list[solution_index, n_solution_iters] = best
                    solution_scores_list[solution_index, n_solution_iters] = best_score
                    old_solution_vecs[solution_index] += qs_stripped_top_k[best_temp]

                n_solution_iters += 1

            # Rule out length-1 decompositions: give the zero-hop (single
            # question) scores a huge value so the argmin below never picks them
            solution_scores_list[:, 0] = 1e9

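            # Take the argmin over the whole (row, column) score grid: the row
            # picks a beam entry and the column its length, so decompositions
            # are variable-length (hence this file's name).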
            best_i, best_j = np.unravel_index(solution_scores_list.argmin(), solution_scores_list.shape)
            best_sol, best_score = solution_list[best_i, :best_j + 1], solution_scores_list[best_i, best_j]
            solutions.append([raw_qs[qi] for qi in best_sol])
            solution_scores.append(best_score)

    print('Saving to file...')
    os.makedirs(args.data_dir, exist_ok=True)
    save_split = 'valid' if args.split == 'dev' else 'train'
    with open(f'{args.data_dir}/{save_split}.sh', 'w') as f:
        f.write('\n'.join(' '.join(sqs) for sqs in solutions) + '\n')
    with open(f'{args.data_dir}/{save_split}.lens', 'w') as f:
        f.write('\n'.join(str(len(sqs)) for sqs in solutions) + '\n')
    with open(f'{args.data_dir}/{save_split}.scores', 'w') as f:
        f.write('\n'.join(str(s) for s in solution_scores) + '\n')
    with open(f'{args.data_dir}/{save_split}.mh', 'w') as f:
        f.write('\n'.join(raw_hard_qs) + '\n')
    print(f'Done! Saved {save_split}.sh/.mh/.lens/.scores to {args.data_dir}')
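
# Example invocation (illustrative flag values; assumes the usual
# `if __name__ == "__main__": main()` guard at module level):
#   python pseudoalignment/pseudo_decomp_variable.py --split dev \
#       --beam_size 100 --target_question_length 2 --data_dir data/pseudo_decomp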