in src/lighteval/metrics/imports/summac.py [0:0]
def build_image(self, original, generated, batch_size=20):
    """Build the NLI probability "image" between a document and a generated text.

    The document is split at granularity ``self.grans[0]`` and limited to the first
    ``self.max_doc_sents`` chunks. The generated text is split at ``self.grans[1]``, or at
    ``self.grans[0]`` when only one granularity is configured. Every (document chunk,
    generated chunk) pair is scored by the NLI model, and the three label probabilities
    are written into a ``(3, N_ori, N_gen)`` array.

    Args:
        original: Source document text (supplies the NLI premises).
        generated: Generated text (supplies the NLI hypotheses).
        batch_size: Number of premise/hypothesis pairs scored per model call.
            Defaults to 20, the value previously hard-coded here.

    Returns:
        A numpy array of shape ``(3, N_ori, N_gen)``:
        - channel 0 holds the entailment probabilities,
        - channel 1 the contradiction probabilities,
        - channel 2 the neutral probabilities.
        If either text yields no chunks, a ``(3, 1, 1)`` zero array is returned. That
        result is not cached.

        When ``self.use_cache`` is enabled, results are memoized in ``self.cache`` under
        the key ``(original, generated)``. A cache hit returns a slice of the stored
        array (a numpy view), so callers should not mutate the returned image.
    """
    cache_key = (original, generated)
    if self.use_cache and cache_key in self.cache:
        # A stored image already has at most max_doc_sents rows (see the split below).
        # The slice only matters if max_doc_sents was lowered after the entry was cached.
        return self.cache[cache_key][:, : self.max_doc_sents, :]

    # A single configured granularity is used for both sides.
    # Otherwise, grans[0] applies to the document and grans[1] to the generated text.
    gran_doc = self.grans[0]
    gran_sum = self.grans[0] if len(self.grans) == 1 else self.grans[1]

    original_chunks = self.split_text(original, granularity=gran_doc)[: self.max_doc_sents]
    generated_chunks = self.split_text(generated, granularity=gran_sum)
    N_ori = len(original_chunks)
    N_gen = len(generated_chunks)
    if N_ori == 0 or N_gen == 0:
        # Nothing to score: return a placeholder all-zero image without loading the model.
        return np.zeros((3, 1, 1))

    image = np.zeros((3, N_ori, N_gen))
    if self.model is None:
        self.load_nli()  # The NLI model is loaded lazily, on the first non-empty input.

    # One record per (document chunk, generated chunk) pair. Each record remembers its
    # cell so scores can be written back after batching.
    dataset = [
        {"premise": premise, "hypothesis": hypothesis, "doc_i": i, "gen_i": j}
        for i, premise in enumerate(original_chunks)
        for j, hypothesis in enumerate(generated_chunks)
    ]

    for batch in batcher(dataset, batch_size=batch_size):
        if self.model_name == "decomp":
            batch_json = [{"premise": d["premise"], "hypothesis": d["hypothesis"]} for d in batch]
            model_outs = self.model.predict_batch_json(batch_json)
            # label_probs is read positionally as (entailment, contradiction, neutral).
            # NOTE(review): confirm this order against the decomp model's label vocabulary.
            label_probs = [out["label_probs"] for out in model_outs]
            batch_evids = [probs[0] for probs in label_probs]
            batch_conts = [probs[1] for probs in label_probs]
            batch_neuts = [probs[2] for probs in label_probs]
        else:
            batch_tokens = self.tokenizer.batch_encode_plus(
                [(d["premise"], d["hypothesis"]) for d in batch],
                padding=True,
                # `truncation=True` selects transformers' "longest_first" strategy.
                # A `truncation_strategy="only_first"` kwarg used to be passed as well, but
                # transformers ignores that deprecated kwarg when `truncation` is set, so it
                # was removed as misleading.
                # NOTE(review): truncation="only_first" would keep hypotheses intact, but the
                # tokenizer can error when the premise alone is too short to absorb the
                # truncation. Confirm before switching.
                truncation=True,
                max_length=self.max_input_length,
                return_tensors="pt",
            )
            batch_tokens = {k: v.to(self.device) for k, v in batch_tokens.items()}
            with torch.no_grad():
                logits = self.model(**batch_tokens)["logits"]
            batch_probs = torch.nn.functional.softmax(logits, dim=-1)
            batch_evids = batch_probs[:, self.entailment_idx].tolist()
            batch_conts = batch_probs[:, self.contradiction_idx].tolist()
            batch_neuts = batch_probs[:, self.neutral_idx].tolist()

        for d, evid, cont, neut in zip(batch, batch_evids, batch_conts, batch_neuts):
            image[:, d["doc_i"], d["gen_i"]] = (evid, cont, neut)

    if self.use_cache:
        self.cache[cache_key] = image
    return image