in drqa/reader/model.py [0:0]
def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None):
    """Rank spans by score_s * score_e, considering only listed candidates.

    Only spans whose untokenized text (or its lowercased form) appears in
    the example's candidate collection are scored.

    Args:
        score_s: per-example start scores, indexed as ``score_s[i][s]``.
            Must support ``.size(0)`` for the number of examples
            (presumably a CPU torch tensor -- confirm against the caller).
        score_e: per-example end scores, indexed as ``score_e[i][e]``.
        candidates: sequence of dicts, one per example, with keys
            ``'input'`` (a tokens object providing
            ``ngrams(n=..., as_strings=False)`` and
            ``slice(s, e).untokenize()``) and ``'cands'`` (collection of
            allowed answer strings). When ``'cands'`` is empty,
            ``PROCESS_CANDS`` from ``..pipeline.drqa`` is used instead.
        top_n: maximum number of spans returned per example.
        max_len: maximum span length in tokens. ``None`` (or 0) means the
            full length of each example's own tokens.

    Returns:
        ``(pred_s, pred_e, pred_score)``: three lists with one entry per
        example. Each entry is a numpy array of start indices, inclusive end
        indices, or scores, ordered by descending score and truncated to
        ``top_n``. An example with no matching candidate gets an empty list
        in all three.

    Raises:
        RuntimeError: if neither the example nor ``PROCESS_CANDS`` provides
            any candidates.
    """
    pred_s = []
    pred_e = []
    pred_score = []
    for i in range(score_s.size(0)):
        # Extract original tokens stored with candidates.
        tokens = candidates[i]['input']
        cands = candidates[i]['cands']
        if not cands:
            # Fall back to the process-global candidates
            # (multiprocessing in pipeline mode).
            from ..pipeline.drqa import PROCESS_CANDS
            cands = PROCESS_CANDS
        if not cands:
            raise RuntimeError('No candidates given.')
        # Membership is tested once per n-gram below. Use a hash set rather
        # than rescanning a list each time; avoid copying an existing set.
        if not isinstance(cands, (set, frozenset)):
            cands = set(cands)
        # Resolve the span limit per example. Do NOT rebind ``max_len``:
        # that would make the first example's length the limit for every
        # later example when the caller passed ``max_len=None``.
        span_len = max_len or len(tokens)
        # Brute force: enumerate all n-grams and compare against candidates.
        scores, s_idx, e_idx = [], [], []
        for s, e in tokens.ngrams(n=span_len, as_strings=False):
            span = tokens.slice(s, e).untokenize()
            if span in cands or span.lower() in cands:
                # Match! ``e`` is exclusive, so the last token is ``e - 1``.
                scores.append(score_s[i][s] * score_e[i][e - 1])
                s_idx.append(s)
                e_idx.append(e - 1)
        if not scores:
            # No candidates present in this example's text.
            pred_s.append([])
            pred_e.append([])
            pred_score.append([])
        else:
            # Rank found candidates by descending score and keep the top_n.
            scores = np.array(scores)
            s_idx = np.array(s_idx)
            e_idx = np.array(e_idx)
            idx_sort = np.argsort(-scores)[0:top_n]
            pred_s.append(s_idx[idx_sort])
            pred_e.append(e_idx[idx_sort])
            pred_score.append(scores[idx_sort])
    return pred_s, pred_e, pred_score