in paq/rerankers/rerank.py [0:0]
def predict(model, tokenizer, qas, cuda=CUDA, bsz=16, fp16=False, top_k=30):
if cuda:
model = model.cuda()
model = to_fp16(model) if fp16 else model
t = time.time()
def log_progress(j, outputs):
t2 = time.time()
logger.info(
f'Reranked {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds '
f'({len(outputs) / (t2 - t): 0.4f} QAs per second)')
def forward(inputs):
logits = model(**inputs)[0]
scores, inds = logits.topk(1, dim=1)
scores, inds = scores.squeeze().tolist(), inds.squeeze().tolist()
if padded_batch:
scores, inds = scores[:-1], inds[:-1]
return scores, inds
outputs = []
output_scores = []
logger.info(f'Embedding {len(qas)} inputs in {len(list(range(0, len(qas), bsz)))} batches:')
with torch.no_grad():
for j, batch_start in enumerate(range(0, len(qas), bsz)):
batch = qas[batch_start: batch_start + bsz]
padded_batch = len(batch) == 1
if padded_batch: # hack for batch size 1 issues
batch = [batch[0],batch[0]]
inputs = tokenize(tokenizer, batch, cuda, top_k)
scores, inds = forward(inputs)
outputs.extend(inds)
output_scores.extend(scores)
log_progress(j, outputs) if j % 1 == 0 else None
log_progress(j, outputs)
return get_output_format(qas, outputs, output_scores)