in src/jobs/tune_t5.py [0:0]
def run_eval(self, eval_dataset, name="Eval", prefix=None, log_wandb=True):
    """Evaluate the model on ``eval_dataset`` and report metrics.

    Each example is tokenized and scored individually (effective batch size 1):
    a teacher-forced forward pass with ``labels`` yields the loss, and
    ``generate`` produces the free-running prediction that is compared against
    the target text.

    Args:
        eval_dataset: iterable of dicts with ``"input_text"`` and
            ``"target_text"`` keys.
        name: label for the W&B table key / stdout line.
        prefix: forwarded to ``self.compute_metrics_text``.
        log_wandb: when True, log metrics and a per-example table to W&B;
            otherwise print the metrics as JSON to stdout.
    """
    results_labels = []
    results_output = []
    losses = []
    # NOTE(review): callers presumably put the model in eval mode before this
    # runs — confirm; dropout would otherwise perturb the reported loss.
    for item in eval_dataset:
        input_text = item["input_text"]  # renamed from `input` (shadowed builtin)
        label = item['target_text']
        input_encodings = self.tokenizer(
            input_text, return_tensors="pt", truncation=True, padding=True
        )
        label_encodings = self.tokenizer(
            label, return_tensors="pt", truncation=True, padding=True
        )
        input_ids = input_encodings["input_ids"].to(self.device)
        attention_mask = input_encodings["attention_mask"].to(self.device)
        label_ids = label_encodings["input_ids"].to(self.device)
        # NOTE(review): pad tokens in label_ids are not replaced with -100.
        # With one sequence per call padding=True adds no pads, but this must
        # be fixed before batching examples together.
        with torch.no_grad():
            forward_out = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=label_ids,
            )
            loss = forward_out.loss.item()
            # Pass attention_mask so padded positions are ignored if inputs
            # are ever padded; with a single sequence the mask is all ones,
            # so this is a no-op today.
            generated = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=30,
                num_return_sequences=1,
            )
        response = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        results_output.append(response)
        results_labels.append(label)
        losses.append(loss)
    metrics = self.compute_metrics_text(results_output, results_labels, prefix=prefix)
    if log_wandb:
        wandb.log(metrics)
        table = wandb.Table(
            columns=["input", "label", "prediction", "loss"],
            data=list(
                zip(
                    [d["input_text"] for d in eval_dataset],
                    results_labels,
                    results_output,
                    losses,
                )
            ),
        )
        wandb.log({f"{name} Set": table})
    else:
        print(f"{name} stats {json.dumps(metrics)}")