in retrieval_eval_bleu.py [0:0]
def predict(context, top_n=5, normalize=False):
    """Retrieve the top-n fixed candidate responses for one context.

    Args:
        context: encoded context tensor for a single example (unbatched;
            a batch dimension is added here — exact shape depends on
            `net`'s encoder, TODO confirm).
        top_n: number of candidates to return.
        normalize: passed through to `score_candidates` (presumably
            score normalization — verify against its definition).

    Returns:
        Tuple of (response, outputs) where `response` is a list of
        `top_n` ("sentence", score) tuples and `outputs` is the parallel
        list of corpus labels ("EmpChat", "DailyDialog", or "Reddit")
        each retrieved candidate came from.
    """
    with torch.no_grad():
        # Add a leading batch dimension of 1 so the encoder sees a batch.
        context = context.unsqueeze(0)
        candidates = fixed_candidates
        if args.cuda:
            context = context.cuda(non_blocking=True)
        ctx, _ = net(context, None)
        scores, indices = score_candidates(ctx, cand_embs, top_n, normalize)
        response = []
        outputs = []
        # NOTE: loop variable renamed from `index` to `cand_idx` to stop
        # shadowing the `indices` tensor; unused enumerate counter dropped.
        for score, cand_idx in zip(scores.squeeze(0), indices.squeeze(0)):
            response.append((stringify(candidates[cand_idx]), float(score)))
            # The candidate pool is the concatenation
            # [EmpChat | DailyDialog | Reddit]; breakingpt and breakingpt2
            # are the boundary offsets between the three corpora.
            if cand_idx < breakingpt:
                outputs.append("EmpChat")
            elif cand_idx < breakingpt2:
                outputs.append("DailyDialog")
            else:
                outputs.append("Reddit")
        return response, outputs