# sagemaker/22_accelerate_sagemaker_examples/src/seq2seq/run_seq2seq_no_trainer.py
import numpy as np
import torch

# `logger` is defined elsewhere in the full script.


def evaluate(args, model, metric, tokenizer, eval_dataloader, accelerator, max_length):
    accelerator.print("starting evaluation")
    count_printed = 0

    def postprocess_text(preds, labels):
        # Strip whitespace; the BLEU metric expects each reference wrapped in a list.
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]
        return preds, labels

    model.eval()
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

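    # Decoding settings for evaluation: target-length bounds plus n-gram and
    # repetition controls to curb degenerate generations.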
    gen_kwargs = {
        "max_length": args.val_max_target_length if args is not None else max_length,
        "num_beams": args.num_beams,
        "min_length": args.val_min_target_length,
        "length_penalty": False,  # falsy, i.e. 0.0: no length normalization during beam search
        "no_repeat_ngram_size": 3,
        "encoder_no_repeat_ngram_size": 3,
        "repetition_penalty": 1.2,
    }
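    # Track how many samples have already been scored so that duplicates added by
    # the distributed sampler to pad the last batch can be dropped below.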
    samples_seen = 0
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                **gen_kwargs,
            )
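
            # Generated sequences can have different lengths on each process, so pad
            # them to a common length before gathering.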
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = batch["labels"]
            if not args.pad_to_max_length:
                # If we did not pad to max length, we need to pad the labels too
                labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
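
            # Gather predictions and labels from every process so the metric sees the
            # full evaluation set.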
            generated_tokens, labels = accelerator.gather((generated_tokens, labels))
            generated_tokens = generated_tokens.cpu().numpy()
            labels = labels.cpu().numpy()

            if args.ignore_pad_token_for_loss:
                # Replace -100 in the labels as we can't decode them.
                labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
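
            # Log a few prompt/generation/reference triples from the first
            # `n_val_batch_generations` batches as a qualitative sanity check.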
            if count_printed < args.n_val_batch_generations:
                logger.info("printing a few sample generations and corresponding labels from the eval set")
                logger.info("prompt | generated | label")
                decoded_prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
                for prompt, generated_response, response in zip(decoded_prompts, decoded_preds, decoded_labels):
                    cleaned_prompt = prompt.replace("<pad>", "").strip()
                    logger.info(f"{cleaned_prompt} | {generated_response} | {response}")
                count_printed += 1

            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

            # If we are in a multiprocess environment, the last batch has duplicates
            if accelerator.num_processes > 1:
                if step == len(eval_dataloader) - 1:
                    decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
                    decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
                else:
                    samples_seen += len(decoded_labels)

            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )

    result = metric.compute()
    logger.info({"bleu": result["score"]})
    accelerator.print("evaluation completed")
    return result["score"]
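

# Example usage (not part of the original script): a minimal sketch of how
# `evaluate` could be wired into a training loop. The names `train_dataloader`,
# `optimizer`, `lr_scheduler`, and `args.num_train_epochs` are assumptions here;
# in practice they would be built and passed through `accelerator.prepare`.
def example_training_loop(args, model, metric, tokenizer, train_dataloader,
                          eval_dataloader, optimizer, lr_scheduler, accelerator, max_length):
    for epoch in range(args.num_train_epochs):
        model.train()
        for batch in train_dataloader:
            # Batches are assumed to contain input_ids, attention_mask, and labels,
            # so the model returns a loss directly.
            outputs = model(**batch)
            accelerator.backward(outputs.loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        # Evaluate BLEU on the validation set at the end of each epoch.
        bleu = evaluate(args, model, metric, tokenizer, eval_dataloader, accelerator, max_length)
        accelerator.print(f"epoch {epoch}: bleu {bleu:.2f}")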