def predict()

in retrieval_eval_bleu.py [0:0]


def predict(context, top_n=5, normalize=False):
    """Retrieve the top-scoring fixed candidates for a single context.

    Args:
        context: encoded context for one example. A leading batch dimension
            is added here with ``unsqueeze(0)`` (presumably ``context`` is a
            1-D tensor of token ids -- TODO confirm against callers).
        top_n: number of candidates requested from ``score_candidates``.
        normalize: forwarded unchanged to ``score_candidates``.

    Returns:
        A pair ``(response, outputs)``:

        * ``response`` -- list of ``(sentence, score)`` tuples, where
          ``sentence`` is ``stringify(candidate)`` and ``score`` is a Python
          float, in the order returned by ``score_candidates``.
        * ``outputs`` -- list of the same length giving, for each candidate,
          "EmpChat", "DailyDialog" or "Reddit", chosen by comparing the
          candidate's index against ``breakingpt`` and ``breakingpt2``
          (presumably the candidate pool is those corpora concatenated in
          that order -- verify against where ``fixed_candidates`` is built).
    """
    with torch.no_grad():
        context = context.unsqueeze(0)
        candidates = fixed_candidates
        if args.cuda:
            context = context.cuda(non_blocking=True)
        ctx, _ = net(context, None)
        scores, indices = score_candidates(ctx, cand_embs, top_n, normalize)
        response = []
        outputs = []
        # squeeze(0) removes the leading batch dimension added above
        # (presumably size 1, since a single context is scored).
        for score, idx in zip(scores.squeeze(0), indices.squeeze(0)):
            # Convert once so indexing and the threshold checks below work
            # on a plain int rather than a 0-d tensor.
            idx = int(idx)
            response.append((stringify(candidates[idx]), float(score)))
            if idx < breakingpt:
                outputs.append("EmpChat")
            elif idx < breakingpt2:
                outputs.append("DailyDialog")
            else:
                outputs.append("Reddit")
        return response, outputs