# pytorch-transformers/pseudoalignment/pseudo_decomp_variable.py
def main():
    """Pseudo-decompose multi-hop HotpotQA questions into single-hop sub-questions.

    For each compositional ("hard") question:
      1. Retrieve the top ``beam_size`` most similar single-hop questions via
         FAISS inner-product search over L2-normalized 300-d vectors.
      2. Greedily grow each candidate decomposition so the *sum* of the chosen
         sub-question vectors approaches the hard question's vector (Euclidean
         distance), up to ``target_question_length`` additional picks.
      3. Keep the best-scoring variable-length decomposition.

    Results are written to ``{data_folder}/{split}.sh`` (sub-questions),
    ``.lens`` (decomposition lengths), ``.scores``, and ``.mh`` (the original
    multi-hop questions), where split 'dev' is saved as 'valid'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default='dev', choices=['train', 'dev'], type=str,
                        help="Find NNs for which HotpotQA split?")
    parser.add_argument("--min_q_len", default=4, type=int,
                        help="Minimum number of spacy tokens allowed in a SQuAD question")
    parser.add_argument("--beam_size", default=100, type=int,
                        help="Top-K most similar questions to comp Q to consider")
    parser.add_argument("--max_q_len", default=20, type=int,
                        help="Max number of spacy tokens allowed in a SQuAD question")
    parser.add_argument("--n_mined", default=0, type=int,
                        help="Number of mined questions to use")
    parser.add_argument("--target_question_length", type=int, default=2)
    parser.add_argument("--data_folder", type=str, help="Path to save results to")
    args = parser.parse_args()
    args.lower_case = False
    args.use_mined_questions = False

    nlp = load_nlp()
    raw_hard_qs = load_compositional_raw_qs(args.split)
    # Load single-hop questions twice: with and without question marks.
    # The un-stripped vectors feed the FAISS index; the stripped ones are
    # used for the additive decomposition scoring below.
    raw_qs, qs = load_single_hop_qs(nlp, args.n_mined, args.min_q_len,
                                    args.max_q_len, strip_qmarks=False,
                                    lower_case=args.lower_case,
                                    use_mined_questions=args.use_mined_questions)
    raw_qs_stripped, qs_stripped = load_single_hop_qs(nlp, args.n_mined, args.min_q_len, args.max_q_len,
                                                      strip_qmarks=True, lower_case=args.lower_case)
    # Inner product over L2-normalized vectors == cosine similarity.
    normed_qs = normalize(qs, axis=1, norm='l2')
    candidate_index = faiss.IndexFlatIP(300)
    candidate_index.add(normed_qs)

    solutions = []
    solution_scores = []

    def nn_search(query, search_space, cached_square_norms=None):
        """Euclidean distance from (1, d) `query` to each row of `search_space`,
        via ||q||^2 + ||x||^2 - 2 q.x (avoids materializing the differences)."""
        if cached_square_norms is None:
            # BUG FIX: squared norms require ** 2 (original summed raw values,
            # cf. the correct all_cached_square_norms computation below).
            cached_square_norms = np.sum(search_space ** 2, axis=1)
        return np.sqrt(np.sum(query ** 2) + cached_square_norms - 2 * np.dot(search_space, query[0]))

    all_cached_square_norms = np.sum(qs_stripped ** 2, axis=1)
    k = args.beam_size
    faiss_batch_size = 500
    for bn in tqdm(range(0, len(raw_hard_qs), faiss_batch_size)):
        raw_hard_qs_batch = raw_hard_qs[bn: bn + faiss_batch_size]
        fbs = len(raw_hard_qs_batch)
        hard_q_stripped_batch = np.zeros((fbs, 300), dtype=np.float32)
        normed_hard_q_batch = np.zeros((fbs, 300), dtype=np.float32)
        for i, raw_hard_q in enumerate(raw_hard_qs_batch):
            _, _, hard_q, _ = vectorize(nlp, raw_hard_q, strip_qmarks=False, lowercase=args.lower_case)
            _, _, hard_q_stripped, _ = vectorize(nlp, raw_hard_q, strip_qmarks=True, lowercase=args.lower_case)
            normed_hard_q = normalize(hard_q[np.newaxis, :], axis=1, norm='l2')
            hard_q_stripped_batch[i] = hard_q_stripped
            normed_hard_q_batch[i] = normed_hard_q
        # One batched FAISS query for the whole mini-batch.
        Dts, Its = candidate_index.search(normed_hard_q_batch, k)
        for i, raw_hard_q in enumerate(raw_hard_qs_batch):
            hard_q_stripped = hard_q_stripped_batch[i: i + 1]
            Dt, It = Dts[i: i + 1], Its[i: i + 1]
            qs_stripped_top_k = qs_stripped[It[0]]
            cached_square_norms = all_cached_square_norms[It[0]]
            zero_hop_scores = nn_search(hard_q_stripped, qs_stripped_top_k, cached_square_norms=cached_square_norms)
            # solution_list[s, t] = global index of the t-th sub-question of beam s.
            solution_list = np.zeros((k, args.target_question_length + 1), dtype=np.int32)
            solution_list[:, 0] = It[0]
            solution_scores_list = np.zeros((k, args.target_question_length + 1))
            solution_scores_list[:, 0] = zero_hop_scores
            # Running sum of chosen sub-question vectors per beam.
            old_solution_vecs = qs_stripped_top_k.copy()
            n_solution_iters = 1
            while n_solution_iters < (args.target_question_length + 1):
                for solution_index in range(k):
                    old_solution_vec = old_solution_vecs[solution_index]
                    # Greedy step: find the candidate closest to the residual.
                    query = hard_q_stripped - old_solution_vec
                    total_scores = nn_search(query, qs_stripped_top_k, cached_square_norms=cached_square_norms)
                    # Don't re-pick this beam's own seed question.
                    total_scores[solution_index] = 1e6
                    best_temp = np.argmin(total_scores)
                    best = It[0, best_temp]
                    best_score = total_scores[best_temp]
                    solution_list[solution_index, n_solution_iters] = best
                    solution_scores_list[solution_index, n_solution_iters] = best_score
                    old_solution_vecs[solution_index] += qs_stripped_top_k[best_temp]
                n_solution_iters += 1
            # Disallow length-1 decompositions (column 0 = single question only).
            solution_scores_list[:, 0] = 1e9
            best_i, best_j = np.unravel_index(solution_scores_list.argmin(), solution_scores_list.shape)
            best_sol, best_score = solution_list[best_i, :best_j + 1], solution_scores_list[best_i, best_j]
            solutions.append([raw_qs[i] for i in best_sol])
            solution_scores.append(best_score)

    print('Saving to file...')
    # BUG FIX: the argparse option is --data_folder (dest 'data_folder');
    # the original referenced args.data_dir, raising AttributeError at save time.
    os.makedirs(args.data_folder, exist_ok=True)
    save_split = 'valid' if args.split == 'dev' else 'train'
    with open(f'{args.data_folder}/{save_split}.sh', 'w') as f:
        f.writelines('\n'.join([' '.join(sqs) for sqs in solutions]) + '\n')
    with open(f'{args.data_folder}/{save_split}.lens', 'w') as f:
        f.writelines('\n'.join([str(len(sqs)) for sqs in solutions]) + '\n')
    with open(f'{args.data_folder}/{save_split}.scores', 'w') as f:
        f.writelines('\n'.join([str(s) for s in solution_scores]) + '\n')
    with open(f'{args.data_folder}/{save_split}.mh', 'w') as f:
        f.writelines('\n'.join(raw_hard_qs) + '\n')
    print(f'Done! Saved to {args.data_folder}/{save_split}.sh mh')