in drqa/reader/model.py
def predict(self, ex, candidates=None, top_n=1, async_pool=None):
    """Forward a batch of examples only to get predictions.

    Args:
        ex: the batch
        candidates: batch * variable length list of string answer options.
            The model will only consider exact spans contained in this list.
        top_n: Number of predictions to return per batch element.
        async_pool: If provided, non-gpu post-processing will be offloaded
            to this CPU process pool.
    Output:
        pred_s: batch * top_n predicted start indices
        pred_e: batch * top_n predicted end indices
        pred_score: batch * top_n prediction scores

    If async_pool is given, these will be AsyncResult handles.
    """
    # Eval mode
    self.network.eval()

    # Transfer to GPU
    if self.use_cuda:
        inputs = [e if e is None else e.cuda(non_blocking=True)
                  for e in ex[:5]]
    else:
        inputs = [e for e in ex[:5]]

    # Run forward
    with torch.no_grad():
        score_s, score_e = self.network(*inputs)

    # Decode predictions
    score_s = score_s.data.cpu()
    score_e = score_e.data.cpu()
    if candidates:
        args = (score_s, score_e, candidates, top_n, self.args.max_len)
        if async_pool:
            return async_pool.apply_async(self.decode_candidates, args)
        else:
            return self.decode_candidates(*args)
    else:
        args = (score_s, score_e, top_n, self.args.max_len)
        if async_pool:
            return async_pool.apply_async(self.decode, args)
        else:
            return self.decode(*args)
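
The final branch hands span decoding to self.decode (or self.decode_candidates when an answer whitelist is supplied), neither of which appears in this excerpt. Below is a minimal sketch of what the unconstrained decoder could look like, assuming the per-token start and end scores are combined by outer product and masked to valid spans of at most max_len tokens; the function name and signature mirror the call above, but the body is illustrative, not the repo's actual implementation.

import numpy as np
import torch

def decode(score_s, score_e, top_n=1, max_len=None):
    """Illustrative sketch: return the top_n spans maximizing
    score_s[s] * score_e[e] subject to s <= e <= s + max_len - 1."""
    pred_s, pred_e, pred_score = [], [], []
    max_len = max_len or score_s.size(1)
    for i in range(score_s.size(0)):
        # Outer product scores every (start, end) pair at once: shape (L, L).
        scores = torch.outer(score_s[i], score_e[i])
        # Zero out inverted (end < start) and over-length spans in place.
        scores.triu_().tril_(max_len - 1)
        scores = scores.numpy()
        flat = scores.flatten()
        if top_n == 1:
            idx_sort = [np.argmax(flat)]
        elif len(flat) < top_n:
            idx_sort = np.argsort(-flat)
        else:
            # Partial sort: find the top_n scores, then order just those.
            idx = np.argpartition(-flat, top_n)[:top_n]
            idx_sort = idx[np.argsort(-flat[idx])]
        s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
        pred_s.append(s_idx)
        pred_e.append(e_idx)
        pred_score.append(flat[idx_sort])
    return pred_s, pred_e, pred_score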
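
A hedged usage sketch of the async path follows; model, batches, and the pool size are placeholders, not from the source. Since the scores are moved to the CPU before dispatch, the pool only ever receives CPU tensors, and each call returns a multiprocessing AsyncResult (as the docstring notes), so the CPU-bound decoding of one batch can overlap with the GPU forward pass of the next:

from multiprocessing import Pool

pool = Pool(processes=4)  # hypothetical pool size
handles = [model.predict(batch, top_n=5, async_pool=pool)
           for batch in batches]  # forward passes run eagerly here
for handle in handles:
    pred_s, pred_e, pred_score = handle.get()  # blocks until decoded
pool.close()
pool.join()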