in drqa/pipeline/drqa.py [0:0]
def process_batch(self, queries, candidates=None, top_n=1, n_docs=5,
                  return_context=False):
    """Run a batch of queries (more efficient)."""
    t0 = time.time()
    logger.info('Processing %d queries...' % len(queries))
    logger.info('Retrieving top %d docs...' % n_docs)

    # Rank documents for queries.
    if len(queries) == 1:
        ranked = [self.ranker.closest_docs(queries[0], k=n_docs)]
    else:
        ranked = self.ranker.batch_closest_docs(
            queries, k=n_docs, num_workers=self.num_workers
        )
    all_docids, all_doc_scores = zip(*ranked)
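    # `ranked` is one (doc_ids, doc_scores) pair per query, so all_docids[i]
    # and all_doc_scores[i] hold the n_docs documents retrieved for queries[i].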
    # Flatten document ids and retrieve text from database.
    # We remove duplicates for processing efficiency.
    flat_docids = list({d for docids in all_docids for d in docids})
    did2didx = {did: didx for didx, did in enumerate(flat_docids)}
    doc_texts = self.processes.map(fetch_text, flat_docids)

    # Split and flatten documents. Maintain a mapping from doc (index in
    # flat list) to split (index in flat list).
    flat_splits = []
    didx2sidx = []
    for text in doc_texts:
        splits = self._split_doc(text)
        didx2sidx.append([len(flat_splits), -1])
        for split in splits:
            flat_splits.append(split)
        didx2sidx[-1][1] = len(flat_splits)
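    # didx2sidx[didx] = [start, end) gives the slice of flat_splits that was
    # produced from document flat_docids[didx].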
    # Push through the tokenizers as fast as possible.
    q_tokens = self.processes.map_async(tokenize_text, queries)
    s_tokens = self.processes.map_async(tokenize_text, flat_splits)
    q_tokens = q_tokens.get()
    s_tokens = s_tokens.get()
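    # Both jobs are submitted with map_async before either .get(), so queries
    # and splits are tokenized concurrently in the worker pool.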
    # Group into structured example inputs. Examples' ids represent
    # mappings to their question, document, and split ids.
    examples = []
    for qidx in range(len(queries)):
        for rel_didx, did in enumerate(all_docids[qidx]):
            start, end = didx2sidx[did2didx[did]]
            for sidx in range(start, end):
                if (len(q_tokens[qidx].words()) > 0 and
                        len(s_tokens[sidx].words()) > 0):
                    examples.append({
                        'id': (qidx, rel_didx, sidx),
                        'question': q_tokens[qidx].words(),
                        'qlemma': q_tokens[qidx].lemmas(),
                        'document': s_tokens[sidx].words(),
                        'lemma': s_tokens[sidx].lemmas(),
                        'pos': s_tokens[sidx].pos(),
                        'ner': s_tokens[sidx].entities(),
                    })
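    # Each example id is a (question index, rank of the doc within that
    # question's retrieval, split index) triple; empty questions or splits
    # are skipped.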
    logger.info('Reading %d paragraphs...' % len(examples))

    # Push all examples through the document reader.
    # We decode argmax start/end indices asynchronously on CPU.
    result_handles = []
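    # Scale the number of loader workers with the workload: roughly one per
    # 1,000 examples, capped at self.max_loaders.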
    num_loaders = min(self.max_loaders, math.floor(len(examples) / 1e3))
    for batch in self._get_loader(examples, num_loaders):
        if candidates or self.fixed_candidates:
            batch_cands = []
            for ex_id in batch[-1]:
                batch_cands.append({
                    'input': s_tokens[ex_id[2]],
                    'cands': candidates[ex_id[0]] if candidates else None
                })
            handle = self.reader.predict(
                batch, batch_cands, async_pool=self.processes
            )
        else:
            handle = self.reader.predict(batch, async_pool=self.processes)
        result_handles.append((handle, batch[-1], batch[0].size(0)))
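    # Each handle is stored with its batch's example ids (batch[-1]) and batch
    # size so the async results can be mapped back to questions and splits.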
    # Iterate through the predictions, and maintain priority queues for
    # top scored answers for each question in the batch.
    queues = [[] for _ in range(len(queries))]
    for result, ex_ids, batch_size in result_handles:
        s, e, score = result.get()
        for i in range(batch_size):
            # We take the top prediction per split.
            if len(score[i]) > 0:
                item = (score[i][0], ex_ids[i], s[i][0], e[i][0])
                queue = queues[ex_ids[i][0]]
                if len(queue) < top_n:
                    heapq.heappush(queue, item)
                else:
                    heapq.heappushpop(queue, item)
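    # Each queue is a min-heap of at most top_n items keyed by span score, so
    # it always retains the current best top_n answers for its question.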
    # Arrange final top prediction data.
    all_predictions = []
    for queue in queues:
        predictions = []
        while len(queue) > 0:
            score, (qidx, rel_didx, sidx), s, e = heapq.heappop(queue)
            prediction = {
                'doc_id': all_docids[qidx][rel_didx],
                'span': s_tokens[sidx].slice(s, e + 1).untokenize(),
                'doc_score': float(all_doc_scores[qidx][rel_didx]),
                'span_score': float(score),
            }
            if return_context:
                prediction['context'] = {
                    'text': s_tokens[sidx].untokenize(),
                    'start': s_tokens[sidx].offsets()[s][0],
                    'end': s_tokens[sidx].offsets()[e][1],
                }
            predictions.append(prediction)
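        # heappop returns answers in ascending span score; reverse so the
        # best-scoring answer comes first for this question.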
        all_predictions.append(predictions[-1::-1])
    logger.info('Processed %d queries in %.4f (s)' %
                (len(queries), time.time() - t0))

    return all_predictions
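# Illustrative usage sketch: assumes `pipeline` is an already-constructed
# instance of this class, with its ranker, reader, document database, and
# worker pool set up elsewhere.
#
#   queries = ['Who wrote Hamlet?', 'Where is the Eiffel Tower?']
#   batch_predictions = pipeline.process_batch(queries, top_n=3, n_docs=10)
#   for query, predictions in zip(queries, batch_predictions):
#       for p in predictions:
#           # Each answer carries retrieval and reader scores alongside the span.
#           print(query, p['span'], p['doc_id'], p['doc_score'], p['span_score'])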