def predict()

in paq/rerankers/rerank.py [0:0]


def predict(model, tokenizer, qas, cuda=CUDA, bsz=16, fp16=False, top_k=30):
    """Run ``model`` over ``qas`` in batches and return the formatted top-1 results.

    Each batch is turned into model inputs by ``tokenize``; the first element
    of the model output is treated as the logits, and for every row the best
    index along dim 1 and its score are recorded.

    Args:
        model: Model invoked as ``model(**inputs)``; element ``[0]`` of its
            output must be a tensor supporting ``topk(1, dim=1)``.
        tokenizer: Forwarded unchanged to ``tokenize``.
        qas: Sequence of items to rerank; must support ``len`` and slicing.
        cuda: If truthy, the model is moved to the GPU; also forwarded to
            ``tokenize``.
        bsz: Number of items per batch.
        fp16: If truthy (and ``cuda`` is truthy), the model is converted with
            ``to_fp16``. Ignored when ``cuda`` is falsy.
        top_k: Forwarded unchanged to ``tokenize``.

    Returns:
        The result of ``get_output_format(qas, indices, scores)``, where
        ``indices`` and ``scores`` hold one top-1 entry per item of ``qas``.
        For empty ``qas`` both lists are empty.
    """
    if cuda:
        model = model.cuda()
        if fp16:
            model = to_fp16(model)

    # NOTE(review): model.eval() is never called here -- presumably the caller
    # has already put the model in eval mode; confirm.

    # range objects support len() directly; compute the batch count once
    # instead of materialising a list on every log call.
    num_batches = len(range(0, len(qas), bsz))
    start_time = time.time()

    def log_progress(batch_idx, n_done):
        """Log how many batches are finished and the running throughput."""
        elapsed = time.time() - start_time
        # Guard against a zero reading from a coarse clock.
        rate = n_done / elapsed if elapsed > 0 else float('inf')
        logger.info(
            f'Reranked {batch_idx + 1} / {num_batches} batches in {elapsed:0.2f} seconds '
            f'({rate: 0.4f} QAs per second)')

    def forward(inputs, drop_last):
        """Return per-row top-1 (scores, indices) as Python lists.

        ``drop_last`` discards the final row, which is the duplicate added
        when a single-item batch was padded to two items.
        """
        logits = model(**inputs)[0]
        scores, inds = logits.topk(1, dim=1)
        scores, inds = scores.squeeze().tolist(), inds.squeeze().tolist()
        if drop_last:
            scores, inds = scores[:-1], inds[:-1]
        return scores, inds

    outputs = []
    output_scores = []
    logger.info(f'Embedding {len(qas)} inputs in {num_batches} batches:')
    with torch.no_grad():
        for j, batch_start in enumerate(range(0, len(qas), bsz)):
            batch = qas[batch_start: batch_start + bsz]

            # Hack for batch size 1: the argument-less squeeze() in forward()
            # would collapse a single-row result to a scalar, so tolist()
            # would not return a list. Duplicate the item and drop the extra
            # row afterwards. (There may be further model-side reasons.)
            padded_batch = len(batch) == 1
            if padded_batch:
                batch = [batch[0], batch[0]]

            inputs = tokenize(tokenizer, batch, cuda, top_k)
            scores, inds = forward(inputs, padded_batch)

            outputs.extend(inds)
            output_scores.extend(scores)

            log_progress(j, len(outputs))

    # Final summary line; skipped for empty input, where no batch index exists
    # (the original raised NameError on the unbound loop variable here).
    if num_batches:
        log_progress(num_batches - 1, len(outputs))

    return get_output_format(qas, outputs, output_scores)