def evaluate()

in sagemaker/22_accelerate_sagemaker_examples/src/seq2seq/run_seq2seq_no_trainer.py


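# Note: this excerpt assumes the module-level imports and objects defined elsewhere
# in run_seq2seq_no_trainer.py, in particular `torch`, `numpy as np`, and `logger`.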
def evaluate(args, model, metric, tokenizer, eval_dataloader, accelerator, max_length):
    accelerator.print("starting evaluation")
    count_printed = 0

    # Strip whitespace and wrap each label in a list, the reference format
    # expected by the BLEU metric used below.
    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]

        return preds, labels

    model.eval()
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

    # Beam-search decoding settings used during evaluation.
    gen_kwargs = {
        "max_length": args.val_max_target_length if args is not None else max_length,
        "num_beams": args.num_beams,
        "min_length": args.val_min_target_length,
        # `False` is treated as 0.0, i.e. no length normalization during beam search.
        "length_penalty": False,
        "no_repeat_ngram_size": 3,
        "encoder_no_repeat_ngram_size": 3,
        "repetition_penalty": 1.2,
    }
    samples_seen = 0
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                **gen_kwargs,
            )

            # Pad generations across processes so `gather` sees equal-length tensors.
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = batch["labels"]
            if not args.pad_to_max_length:
                # If we did not pad to max length, we need to pad the labels too
                labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

            generated_tokens, labels = accelerator.gather((generated_tokens, labels))
            generated_tokens = generated_tokens.cpu().numpy()
            labels = labels.cpu().numpy()

            if args.ignore_pad_token_for_loss:
                # Replace -100 in the labels as we can't decode them.
                labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            if count_printed < args.n_val_batch_generations:
                logger.info("printing few sample generations and corresponding labels from eval set")
                logger.info("prompt | generated | label")
                decoded_prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
                for prompt, generated_response, response in zip(decoded_prompts, decoded_preds, decoded_labels):
                    cleaned_prompt = prompt.replace("<pad>", "").strip()
                    logger.info(f"{cleaned_prompt} | {generated_response} | {response}")
                count_printed += 1

            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
            # In a multi-process run, the sampler pads the last batch with duplicated samples so
            # that every process receives the same number; drop them so each example is scored once.
            if accelerator.num_processes > 1:
                if step == len(eval_dataloader) - 1:
                    decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
                    decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
                else:
                    samples_seen += len(decoded_labels)

            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )
    result = metric.compute()
    logger.info({"bleu": result["score"]})
    accelerator.print("evaluation completed")
    return result["score"]