def embed()

in paq/retrievers/embed.py
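
Embeds the questions from a list of QA dicts in batches, optionally on GPU and in fp16, and returns all embeddings as a single CPU tensor.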


import time
import logging

import torch

logger = logging.getLogger(__name__)

# Module-level defaults assumed from the surrounding file: CUDA flags GPU
# availability, and to_fp16 casts a model's weights to half precision.
CUDA = torch.cuda.is_available()


def to_fp16(model):
    return model.half()


def embed(model, tokenizer, qas, bsz=256, cuda=CUDA, fp16=False):

    def normalize_q(question: str) -> str:
        # Strip surrounding whitespace and trailing question marks, lowercase.
        return question.strip().strip('?').lower().strip()

    def tokenize(batch_qas):
        # Normalize and tokenize a batch of questions, padding to the longest
        # question in the batch, and move the tensors to GPU if requested.
        input_qs = [normalize_q(q['question']) for q in batch_qas]
        inputs = tokenizer.batch_encode_plus(
            input_qs, return_tensors='pt', padding=True, add_special_tokens=True
        )
        return {k: v.cuda() for k, v in inputs.items()} if cuda else inputs

    if cuda:
        model = model.cuda()
        model = to_fp16(model) if fp16 else model
    model.eval()  # disable dropout so embeddings are deterministic

    t = time.time()
    n_batches = (len(qas) + bsz - 1) // bsz

    def log_progress(j, outputs):
        # Report throughput: batches completed and QAs embedded per second.
        t2 = time.time()
        n_embedded = sum(len(o) for o in outputs)
        logger.info(
            f'Embedded {j + 1} / {n_batches} batches in {t2 - t:0.2f} seconds '
            f'({n_embedded / (t2 - t):0.4f} QAs per second)'
        )

    outputs = []
    with torch.no_grad():
        for j, batch_start in enumerate(range(0, len(qas), bsz)):
            # Tokenize and embed one batch, moving results back to CPU so
            # GPU memory is not held across batches.
            batch_qas = qas[batch_start: batch_start + bsz]
            inputs = tokenize(batch_qas)
            batch_outputs = model(**inputs)
            outputs.append(batch_outputs.cpu())
            if j % 10 == 0:
                log_progress(j, outputs)

    if outputs:  # guard against empty qas, where j would be unbound
        log_progress(j, outputs)

    # Batches are already on CPU, so concatenation yields a CPU tensor.
    return torch.cat(outputs, dim=0)
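
A minimal usage sketch, assuming embed is called with a model whose forward pass returns one embedding tensor per input. The MeanPooler wrapper, the bert-base-uncased checkpoint, and the toy qas list below are illustrative assumptions, not part of the original file:

from transformers import AutoTokenizer, AutoModel
import torch

class MeanPooler(torch.nn.Module):
    # Hypothetical wrapper: embed() calls model(**inputs) and expects a plain
    # tensor back, so we mean-pool the encoder's final hidden states.
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask, **kwargs):
        hidden = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).float()
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1)

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = MeanPooler(AutoModel.from_pretrained('bert-base-uncased'))
qas = [{'question': 'Who wrote Hamlet?'},
       {'question': 'What is the capital of France?'}]
embeddings = embed(model, tokenizer, qas, bsz=2, cuda=False)
print(embeddings.shape)  # torch.Size([2, 768])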