# paq/retrievers/embed.py
import time
import logging

import torch

logger = logging.getLogger(__name__)

# NOTE: `CUDA` (default device flag) and `to_fp16` (half-precision cast
# helper) are module-level names assumed to be defined elsewhere in the
# PAQ codebase.


def embed(model, tokenizer, qas, bsz=256, cuda=CUDA, fp16=False):
    """Embed the 'question' field of each QA dict in `qas` in batches of
    `bsz`; returns a single [num_questions, dim] tensor on CPU."""

    def normalize_q(question: str) -> str:
        # Strip surrounding whitespace and question marks, then lowercase.
        return question.strip().strip('?').lower().strip()
    def tokenize(batch_qas):
        input_qs = [normalize_q(q['question']) for q in batch_qas]
        inputs = tokenizer.batch_encode_plus(
            input_qs, return_tensors='pt', padding=True, add_special_tokens=True
        )
        return {k: v.cuda() for k, v in inputs.items()} if cuda else inputs
    if cuda:
        model = model.cuda()
    model = to_fp16(model) if fp16 else model
    t = time.time()
    n_batches = (len(qas) + bsz - 1) // bsz  # ceil(len(qas) / bsz)

    def log_progress(j, outputs):
        t2 = time.time()
        logger.info(
            f'Embedded {j + 1} / {n_batches} batches in {t2 - t:0.2f} seconds '
            f'({sum(len(o) for o in outputs) / (t2 - t):0.4f} QAs per second)'
        )
    outputs = []
    with torch.no_grad():
        for j, batch_start in enumerate(range(0, len(qas), bsz)):
            batch_qas = qas[batch_start: batch_start + bsz]
            inputs = tokenize(batch_qas)
            # The encoder is expected to return a [batch, dim] embedding
            # tensor; move it off-GPU right away to keep memory bounded.
            batch_outputs = model(**inputs)
            outputs.append(batch_outputs.cpu())
            if j % 10 == 0:
                log_progress(j, outputs)

    log_progress(j, outputs)
    return torch.cat(outputs, dim=0).cpu()
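
Below is a hypothetical usage sketch, not part of the original file: `ToyEncoder` is an illustrative stand-in for PAQ's retriever encoder that only assumes the contract `embed` relies on, namely that `model(**inputs)` returns a `[batch, dim]` tensor, and the tokenizer can be any Hugging Face tokenizer that supports `batch_encode_plus`.

import torch
from transformers import AutoTokenizer

class ToyEncoder(torch.nn.Module):
    # Hypothetical stand-in encoder (not PAQ's model): mean-pools token
    # embeddings so that forward(**inputs) returns a [batch, dim] tensor.
    def __init__(self, vocab_size=30522, dim=128):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        vecs = self.emb(input_ids)                       # [batch, seq, dim]
        if attention_mask is None:
            return vecs.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).float()
        return (vecs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
qas = [{'question': 'Who wrote Hamlet?'},
       {'question': 'What is the capital of France?'}]
vectors = embed(ToyEncoder(), tokenizer, qas, bsz=2, cuda=False)
print(vectors.shape)  # torch.Size([2, 128])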