def on_epoch_end()

in src/transformers/keras_callbacks.py


    def on_epoch_end(self, epoch, logs=None):
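        # Some model outputs (those listed in config.keys_to_ignore_at_inference) are not useful as
        # metric inputs, so note them here and drop them from the prediction dict below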
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

        main_input_name = None
        if self.predict_with_generate:
            # This dense conditional recognizes the case where we have an encoder-decoder model, but
            # avoids getting tangled up when we just have a model with a layer called 'encoder'
            if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
                main_input_name = self.model.encoder.main_input_name
            else:
                main_input_name = getattr(self.model, "main_input_name", "input_ids")

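            # Build the XLA-compiled generation function once and cache it on self, so later batches
            # and epochs reuse it instead of constructing a new tf.function every time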
            if self.use_xla_generation and self.generation_function is None:

                def generation_function(inputs, attention_mask):
                    return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)

                self.generation_function = tf.function(generation_function, jit_compile=True)

        prediction_list = []
        label_list = []

        # The whole predict/generate loop is handled inside this method
        for batch in self.eval_dataset:
            if isinstance(batch, tuple):
                batch, labels = batch
            else:
                labels = None
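            # Generation path: pull the model's main input (and attention mask, if present) out of the batch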
            if self.predict_with_generate:
                if isinstance(batch, dict):
                    generation_inputs = batch[main_input_name]
                    attention_mask = batch.get("attention_mask", None)
                else:
                    generation_inputs = batch
                    attention_mask = None
                if self.use_xla_generation:
                    predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
                else:
                    predictions = self.model.generate(
                        generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
                    )
            else:
                predictions = self.model.predict_on_batch(batch)
                if isinstance(predictions, dict):
                    # This converts any dict-subclass to a regular dict
                    # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
                    predictions = dict(predictions)
                    if self.output_cols is not None:
                        predictions = {key: predictions[key] for key in self.output_cols}
                    else:
                        predictions = {
                            key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
                        }
            prediction_list.append(predictions)
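            # Gather the matching labels, either straight from the batch dict (label_cols) or from the
            # Keras-style labels element of the tuple, converting them to numpy arrays in both cases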
            if not self.use_keras_label:
                labels = {key: batch[key].numpy() for key in self.label_cols}
            elif isinstance(labels, dict):
                labels = {key: array.numpy() for key, array in labels.items()}
            elif isinstance(labels, (list, tuple)):
                labels = [array.numpy() for array in labels]
            elif isinstance(labels, tf.Tensor):
                labels = labels.numpy()
            else:
                raise TypeError(f"Confused by labels of type {type(labels)}")
            label_list.append(labels)

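        # Collapse the per-batch lists into single structures (arrays, or dicts/lists of arrays) for metric_fn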
        all_preds = self._postprocess_predictions_or_labels(prediction_list)
        all_labels = self._postprocess_predictions_or_labels(label_list)

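        # metric_fn receives a single (predictions, labels) tuple and must return a dict of named metric values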
        metric_output = self.metric_fn((all_preds, all_labels))
        if not isinstance(metric_output, dict):
            raise TypeError(
                f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
            )
        # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
        # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
        # new keys in there, which will then get read by the History callback and treated like any other metric value.
        # I promise that I have it in writing from Chollet that this is okay.
        logs.update(metric_output)
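
For context, a minimal usage sketch of this method's contract. It assumes the enclosing class is KerasMetricCallback from this module, a simple classification model whose filtered prediction dict contains "logits", and a validation dataset whose batches are dicts that include a "labels" column; model, tf_train_set and tf_validation_set are placeholder names, not part of the code above. The metric function follows the calling convention shown above: a single (predictions, labels) tuple in, a dict of named metric values out.

    import numpy as np
    from transformers.keras_callbacks import KerasMetricCallback

    def compute_accuracy(eval_tuple):
        # Called with the (all_preds, all_labels) tuple assembled in on_epoch_end(); must return a dict,
        # since those keys are written straight into the Keras logs
        predictions, labels = eval_tuple
        logits = predictions["logits"] if isinstance(predictions, dict) else predictions
        label_ids = labels["labels"] if isinstance(labels, dict) else labels
        return {"custom_accuracy": float(np.mean(np.argmax(logits, axis=-1) == label_ids))}

    # model, tf_train_set and tf_validation_set are placeholders for your own model and tf.data pipelines
    metric_callback = KerasMetricCallback(
        metric_fn=compute_accuracy,
        eval_dataset=tf_validation_set,
        label_cols=["labels"],
    )
    model.fit(tf_train_set, epochs=3, callbacks=[metric_callback])

Because logs.update() injects the returned keys into the epoch logs, "custom_accuracy" then shows up in history.history like any built-in metric and can be monitored by other callbacks, for example tf.keras.callbacks.EarlyStopping(monitor="custom_accuracy", mode="max").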