in pytorch-transformers/pseudoalignment/pseudo_decomp_bert_nsp.py [0:0]
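# NOTE: this excerpt assumes the module-level imports and constants defined earlier in the file
# (not shown here): argparse, json, logging, numpy as np, faiss, torch, a `normalize` helper
# (e.g. sklearn.preprocessing.normalize), plus DEBUG, BATCH_SIZE, TOKENIZER_CLASS, MODEL_CLASS,
# PRETRAINED_WEIGHTS and the get_*_path / dump_pseudoalignments helpers.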
def main_bert_nsp():
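# Pseudo-decompose each multi-hop HotpotQA question into a pair of single-hop SQuAD questions:
# retrieve nearest-neighbour SQuAD questions with FAISS over BERT embeddings, then rank single
# candidates and candidate pairs with BERT next-sentence-prediction (NSP) scores.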
parser = argparse.ArgumentParser()
parser.add_argument("--split", default='dev', choices=['train', 'dev'], type=str,
help="Find NNs for which HotpotQA split?")
parser.add_argument("--min_q_len", default=4, type=int,
help="Minimum number of spacy tokens allowed in a SQuAD question")
parser.add_argument("--max_q_len", default=20, type=int,
help="Max number of spacy tokens allowed in a SQuAD question")
parser.add_argument("--layer", default=0, type=int,
help="layer of bert")
parser.add_argument("--top_k_to_search", default=100, type=int)
parser.add_argument("--top_k_to_pairwise_rank", default=20, type=int)
parser.add_argument("--batch_num", type=int, help="which batch to compute")
parser.add_argument("--data_folder", type=str, help="Path to save results to")
args = parser.parse_args()
# Load questions and convert to L2-normalized vectors
print('Loading easy questions...')
with open(get_squad_path('train')) as f:
data_easy = json.load(f)
qs_all = np.load(get_bert_embedding_path('squad', 'train', args.layer))
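# Keep only SQuAD questions within the configured length bounds; qs_inds records their row
# indices in the embedding matrix so the filtered embeddings stay aligned with raw_qs.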
qs_inds = []
raw_qs = []
c = 0
for article in data_easy['data']:
for paragraph in article['paragraphs']:
for qa in paragraph['qas']:
raw_q = qa['question'].strip()
if '?' not in raw_q:
raw_q += '?'
if args.min_q_len <= len(raw_q.split()) <= args.max_q_len:
raw_qs.append(raw_q)
qs_inds.append(c)
c += 1
qs = qs_all[qs_inds]
print('Loading hard questions...')
qs_hard = np.load(get_bert_embedding_path('hotpotqa', args.split, args.layer))
with open(get_hotpot_path(args.split)) as f:
data_hard = json.load(f)
raw_qs_hard = []
for qa in data_hard:
raw_q_hard = qa['question'].strip()
raw_qs_hard.append(raw_q_hard)
print(f'Loaded {len(qs_hard)} hard questions!')
print('Indexing easy Qs...')
index = faiss.IndexFlatIP(qs.shape[1]) # L2 Norm then Inner Product == Cosine Similarity
normed_qs = normalize(qs, axis=1, norm='l2')
normed_qs_hard = normalize(qs_hard, axis=1, norm='l2')
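# faiss expects float32 vectors; the precomputed embeddings are assumed to already be float32.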
index.add(normed_qs)
print(f'Total Qs indexed: {index.ntotal}')
tokenizer = TOKENIZER_CLASS.from_pretrained(PRETRAINED_WEIGHTS)
model = MODEL_CLASS.from_pretrained(PRETRAINED_WEIGHTS, torchscript=True)
model.eval()
model.cuda()
if DEBUG:
raw_qs_hard = raw_qs_hard[:50]
else:
# Slice the embedding matrix together with the raw questions so mh_ind below stays aligned
batch_start = args.batch_num * 1000
raw_qs_hard = raw_qs_hard[batch_start: batch_start + 1000]
normed_qs_hard = normed_qs_hard[batch_start: batch_start + 1000]
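# Tokenize each multi-hop ('mh') question once up front so it is not re-encoded for every candidate batch.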
mh_tokens_cache = {}
for mh in raw_qs_hard:
tokenized_mh = tokenizer.encode(mh, add_special_tokens=True)
mh_tokens_cache[mh] = tokenized_mh
k = args.top_k_to_search
k2 = args.top_k_to_pairwise_rank
def build_batch(long_q, q_batch):
tokenized_mh = mh_tokens_cache[long_q]
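# Build [CLS] mh [SEP] q [SEP] pairs: drop each candidate's leading [CLS] before concatenating.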
input_ids = [
tokenized_mh + tokenizer.encode(q, add_special_tokens=True)[1:]
for q in q_batch
]
input_ids = [t[:512] for t in input_ids]
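# Right-pad every sequence to the longest in the batch with the PAD token id.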
max_l = max([len(inp) for inp in input_ids])
input_tensor = torch.tensor([
inp + [tokenizer.pad_token_id for _ in range(max_l - len(inp))]
for inp in input_ids]
)
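# Segment ids: 0 for the multi-hop question, 1 for the candidate question(s).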
segment_tensor = torch.ones(input_tensor.shape, dtype=torch.int64)
segment_tensor[:, :len(tokenized_mh)] = 0
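# Attention mask: 1 for real tokens, 0 for padding.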
attention_mask = (input_tensor == tokenizer.pad_token_id)
attention_mask = 1 - attention_mask.to(torch.int64)
return input_tensor, segment_tensor, attention_mask
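# For every hard question, record the two chosen sub-questions and their concatenation.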
nsp_sh1 = []
nsp_sh2 = []
nsp_pairs = []
with torch.no_grad():
for mh_ind, mh in enumerate(raw_qs_hard):
if mh_ind % 10 == 0:
logging.info(f'Completed {mh_ind} questions')
single_score_map = {}
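# Retrieve the k most similar SQuAD questions for this hard question (cosine similarity via the inner-product index).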
Dt, It = index.search(normed_qs_hard[mh_ind:mh_ind + 1], k)
raw_qs_to_search = [raw_qs[ind] for ind in It[0]]
for b in range(0, len(raw_qs_to_search), BATCH_SIZE):
short_q_batch = raw_qs_to_search[b:b + BATCH_SIZE]
input_ids, segments, att_mask = build_batch(mh, short_q_batch)
out = model(
input_ids=input_ids.cuda(),
token_type_ids=segments.cuda(),
attention_mask=att_mask.cuda()
)
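# out[0] holds the NSP logits; softmax()[:, 0] is the probability of the 'is next sentence' class.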
batch_scores = torch.nn.functional.softmax(out[0], dim=1)[:, 0].cpu()
for sq, score in zip(short_q_batch, batch_scores.tolist()):
single_score_map[sq] = score
top_singles = sorted(single_score_map.items(), key=lambda x: -x[1])[:k2]
raw_q_to_search_2 = [sq for sq, _ in top_singles]
pairs_to_search = []
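# Score every ordered pair of surviving candidates; NSP is order-sensitive, so (a, b) and (b, a) are scored separately.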
for sha in raw_q_to_search_2:
for shb in raw_q_to_search_2:
if sha != shb:
pairs_to_search.append((sha, shb))
pair_score_map = {}
for b in range(0, len(pairs_to_search), BATCH_SIZE):
pair_batch = pairs_to_search[b:b + BATCH_SIZE]
input_ids, segments, att_mask = build_batch(mh, [sha + ' ' + shb for sha, shb in pair_batch])
out = model(
input_ids=input_ids.cuda(),
token_type_ids=segments.cuda(),
attention_mask=att_mask.cuda()
)
batch_scores = torch.nn.functional.softmax(out[0], dim=1)[:, 0].cpu()
for pair, pair_score in zip(pair_batch, batch_scores.tolist()):
pair_score_map[pair] = pair_score
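# A pair's final score is the sum of its two single NSP scores and its pairwise NSP score; keep the best pair.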
final_scores = []
for sha, shb in pairs_to_search:
final_scores.append((sha, shb, single_score_map[sha] + single_score_map[shb] + pair_score_map[(sha, shb)]))
sha, shb, score = max(final_scores, key=lambda x: x[-1])
nsp_sh1.append(sha)
nsp_sh2.append(shb)
nsp_pairs.append(sha + ' ' + shb)
print('Saving to file...')
save_split = 'valid' if args.split == 'dev' else 'train'
dump_pseudoalignments(args.data_folder, save_split, raw_qs_hard, nsp_sh1, nsp_sh2, nsp_pairs)