in src/lighteval/metrics/imports/bert_scorer.py [0:0]
def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
    """Compute BERTScore precision/recall/F1 for candidate/reference pairs.

    Args:
        cands (list of str): candidate sentences.
        refs (list of str or list of list of str): reference sentences. When
            an element is a list of strings, every reference in the group is
            scored against its candidate and the *best* score is kept.
        verbose (bool): if True, log progress and throughput.
        batch_size (int): batch size forwarded to the scoring routine.
        return_hash (bool): accepted for interface compatibility, but this
            implementation always returns only the score tuple — NOTE(review):
            no hashcode is ever computed or returned here.

    Returns:
        Tuple of tensors ``(P, R, F)``, each of shape ``(N,)`` where ``N`` is
        the number of input candidates (after best-of-group reduction for
        multi-reference input).
    """
    # Lazily load the tokenizer/model on first use and move it to the device.
    if self._model is None:
        logger.info(f"Loading BERTScorer model `{self._model_type}`")
        self._tokenizer = AutoTokenizer.from_pretrained(self._model_type)
        self._model = AutoModel.from_pretrained(self._model_type)
        self._model.eval()
        self._model.to(self.device)

    # Multi-reference case: flatten into parallel (cand, ref) pairs and record
    # each candidate's slice so we can take the max over its references later.
    ref_group_boundaries = None
    if not isinstance(refs[0], str):
        ref_group_boundaries = []
        ori_cands, ori_refs = cands, refs
        cands, refs = [], []
        count = 0
        for cand, ref_group in zip(ori_cands, ori_refs):
            cands += [cand] * len(ref_group)
            refs += ref_group
            ref_group_boundaries.append((count, count + len(ref_group)))
            count += len(ref_group)

    if verbose:
        logger.info("calculating scores...")
        # BUGFIX: use a dedicated name for the timer. The original stored it in
        # `start`, which was clobbered by the `for start, end in ...` loop
        # below, so the elapsed-time log was wrong for multi-reference input.
        start_time = time.perf_counter()

    if self.idf:
        assert self._idf_dict, "IDF weights are not computed"
        idf_dict = self._idf_dict
    else:
        # Uniform token weighting, but zero out the special tokens.
        idf_dict = defaultdict(lambda: 1.0)
        idf_dict[self._tokenizer.sep_token_id] = 0
        idf_dict[self._tokenizer.cls_token_id] = 0

    all_preds = bert_cos_score_idf(
        self._model,
        refs,
        cands,
        self._tokenizer,
        idf_dict,
        verbose=verbose,
        device=self.device,
        batch_size=batch_size,
        all_layers=self.all_layers,
    ).cpu()

    if ref_group_boundaries is not None:
        # Keep the best score among each candidate's references.
        max_preds = []
        for beg, end in ref_group_boundaries:
            max_preds.append(all_preds[beg:end].max(dim=0)[0])
        all_preds = torch.stack(max_preds, dim=0)

    if self.rescale_with_baseline:
        all_preds = (all_preds - self.baseline_vals) / (1 - self.baseline_vals)

    out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2]  # P, R, F

    if verbose:
        time_diff = time.perf_counter() - start_time
        logger.info(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")

    return out