src/biencoder_predict_qa.py
import collections
import json

import torch
import torch.nn.functional as F
from fairseq.data.data_utils import collate_tokens
from tqdm import tqdm


def run_prediction(roberta, heads, examples, args):
    C_MAX = args.context_max_len
    STRIDE = args.context_stride
    # Both must be even so the chunk-stitching slices below
    # (C_MAX // 2 +/- STRIDE // 2) tile the passage exactly
    assert C_MAX % 2 == 0
    assert STRIDE % 2 == 0
    # Run prediction in batches
    pred_json = {}
    i = 0
    with tqdm(total=len(examples)) as pbar:
        while i < len(examples):
            # Greedily grow the next batch until batch_size or the
            # max_tokens budget is reached
            cur_batch = []
            cur_tokens = 0
            qid_to_batch_index = collections.defaultdict(list)
            qid_to_full_bpe = {}
            while i < len(examples) and len(cur_batch) < args.batch_size:
                qid, p_bpe, q_bpe = examples[i]
                if len(p_bpe) > C_MAX:
                    # Passage is too long: split it into overlapping chunks
                    # of length C_MAX, advancing by STRIDE tokens each time
                    p_bpe_list = []
                    for j in range(0, len(p_bpe) - C_MAX + STRIDE, STRIDE):
                        cur_endpoint = min(len(p_bpe), j + C_MAX)
                        p_bpe_list.append(p_bpe[j:cur_endpoint])
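                    # For example, with C_MAX = 512 and STRIDE = 256, a
                    # 900-token passage yields chunks [0:512], [256:768],
                    # and [512:900].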
                    cur_tokens = max(cur_tokens, C_MAX, len(q_bpe))
                else:
                    p_bpe_list = [p_bpe]
                    cur_tokens = max(cur_tokens, len(p_bpe), len(q_bpe))
                # cur_tokens tracks the longest sequence seen so far, so
                # batch_items * cur_tokens upper-bounds the padded batch size.
                # The cur_batch guard ensures an oversized example is still
                # admitted on its own instead of looping forever.
                if cur_batch and (len(cur_batch) + len(p_bpe_list)) * cur_tokens > args.max_tokens:
                    # Wait until the next batch to process this example
                    break
                # Add every chunk of this example to the batch
                for p_bpe_chunk in p_bpe_list:
                    qid_to_batch_index[qid].append(len(cur_batch))
                    cur_batch.append((qid, p_bpe_chunk, q_bpe))
                qid_to_full_bpe[qid] = p_bpe
                i += 1
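            # At this point cur_batch holds at most batch_size sequences,
            # and, e.g., with args.max_tokens = 8192 and a longest sequence
            # of 512 tokens, at most 8192 // 512 = 16 sequences.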
            # Encode the batch and score every passage position
            with torch.no_grad():
                # pad_idx=1 is RoBERTa's <pad> token
                p_tokens = collate_tokens([b[1] for b in cur_batch], pad_idx=1)
                q_tokens = collate_tokens([b[2] for b in cur_batch], pad_idx=1)
                p_features = roberta.model(p_tokens.to(device=roberta.device),
                                           features_only=True,
                                           return_all_hiddens=False)[0]  # B x L x d
                q_features = roberta.model(q_tokens.to(device=roberta.device),
                                           features_only=True,
                                           return_all_hiddens=False)[0]  # B x L x d
                q_cls = q_features[:, 0, :]  # B x d
                start_vec = heads['start'](q_cls)  # B x d
                end_vec = heads['end'](q_cls)  # B x d
                start_logits = torch.matmul(p_features, start_vec.unsqueeze(-1))[:, :, 0]
                end_logits = torch.matmul(p_features, end_vec.unsqueeze(-1))[:, :, 0]
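                # Bi-encoder scoring: the question is encoded independently
                # of the passage, its leading <s> (CLS) feature is projected
                # into start and end query vectors, and position l of passage
                # b is scored by the inner products
                # p_features[b, l] . start_vec[b] and
                # p_features[b, l] . end_vec[b], giving B x L logits.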
            for qid, batch_idxs in qid_to_batch_index.items():
                if len(batch_idxs) == 1:
                    # Easy case: the passage fit in a single chunk
                    cur_start_logits = start_logits[batch_idxs[0], :]
                    cur_end_logits = end_logits[batch_idxs[0], :]
                else:
                    # Stitch the overlapping chunks back into one logit
                    # sequence: each chunk contributes the region it is
                    # closest to the center of
                    cur_start_list = []
                    cur_end_list = []
                    for j in batch_idxs:
                        if j == batch_idxs[0]:  # Case 1: first chunk
                            cur_start_list.append(start_logits[j, :C_MAX // 2 + STRIDE // 2])
                            cur_end_list.append(end_logits[j, :C_MAX // 2 + STRIDE // 2])
                        elif j == batch_idxs[-1]:  # Case 2: last chunk (truncated below)
                            cur_start_list.append(start_logits[j, C_MAX // 2 - STRIDE // 2:])
                            cur_end_list.append(end_logits[j, C_MAX // 2 - STRIDE // 2:])
                        else:  # Case 3: middle chunk
                            cur_start_list.append(start_logits[j, C_MAX // 2 - STRIDE // 2:C_MAX // 2 + STRIDE // 2])
                            cur_end_list.append(end_logits[j, C_MAX // 2 - STRIDE // 2:C_MAX // 2 + STRIDE // 2])
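                    # Worked example: with C_MAX = 512 and STRIDE = 256 the
                    # first chunk contributes its logits [0:384), each middle
                    # chunk its central [128:384) (a fresh 256 positions of
                    # the full passage), and the last chunk [128:] of itself,
                    # so the concatenated pieces tile the passage contiguously.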
                    # Positions in the last chunk past the end of the original
                    # passage are padding; truncate to the true length
                    cur_start_logits = torch.cat(cur_start_list)[:len(qid_to_full_bpe[qid])]
                    cur_end_logits = torch.cat(cur_end_list)[:len(qid_to_full_bpe[qid])]
                start_probs = F.softmax(cur_start_logits, dim=0).tolist()
                end_probs = F.softmax(cur_end_logits, dim=0).tolist()
                best_start, best_end = find_span(start_probs, end_probs)
                pred_str = roberta.decode(qid_to_full_bpe[qid][best_start:best_end + 1])
                pred_json[qid] = pred_str
                pbar.update(1)
    with open(args.output_file, 'w') as f:
        json.dump(pred_json, f)
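

# find_span is called above but not shown in this excerpt. Below is a minimal
# sketch of a plausible implementation, assuming it returns the (start, end)
# pair maximizing start_probs[s] * end_probs[e] over spans with s <= e; the
# repo's actual version may differ, and max_span_len is a hypothetical
# parameter added here to keep the search over spans bounded.
def find_span(start_probs, end_probs, max_span_len=30):
    best_start, best_end, best_score = 0, 0, -1.0
    for s, p_start in enumerate(start_probs):
        # Only consider ends at or after the start, within the length cap
        for e in range(s, min(s + max_span_len, len(end_probs))):
            score = p_start * end_probs[e]
            if score > best_score:
                best_start, best_end, best_score = s, e, score
    return best_start, best_end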