in src/transformers/keras_callbacks.py [0:0]
def on_epoch_end(self, epoch, logs=None):
    """Evaluate `self.eval_dataset` and merge the output of `self.metric_fn` into `logs`.

    Every batch of `self.eval_dataset` is run through either `self.model.generate` (when
    `self.predict_with_generate` is set) or `self.model.predict_on_batch`. Predictions and labels are
    collected per batch, passed through `self._postprocess_predictions_or_labels`, and handed to
    `self.metric_fn` as a single `(predictions, labels)` tuple.

    Args:
        epoch: Index of the epoch that just ended. Not used by this method.
        logs: Dict of metric values for the epoch. It is updated in place with the dict returned by
            `self.metric_fn`. When `None`, the metric is still computed and validated, but there is
            nothing to write the values into.

    Raises:
        TypeError: If `self.use_keras_label` is set and the labels are not a dict, list, tuple or
            `tf.Tensor`, or if `self.metric_fn` does not return a dict.
    """
    if hasattr(self.model, "config"):
        ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
    else:
        ignore_keys = []
    # Built once instead of once per output key per batch. `or ()` tolerates a config that sets
    # the attribute to None. "loss" is always dropped from the predictions.
    excluded_keys = set(ignore_keys or ()) | {"loss"}

    main_input_name = None
    if self.predict_with_generate:
        # This dense conditional recognizes the case where we have an encoder-decoder model, but
        # avoids getting tangled up when we just have a model with a layer called 'encoder'
        if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
            main_input_name = self.model.encoder.main_input_name
        else:
            main_input_name = getattr(self.model, "main_input_name", "input_ids")

        # Wrap the generation call lazily and cache it on the instance so later epochs reuse it
        if self.use_xla_generation and self.generation_function is None:

            def generation_function(inputs, attention_mask):
                return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)

            self.generation_function = tf.function(generation_function, jit_compile=True)

    prediction_list = []
    label_list = []

    # The whole predict/generate loop is handled inside this method
    for batch in self.eval_dataset:
        # A tuple batch is unpacked as (inputs, labels); anything else is treated as inputs only
        if isinstance(batch, tuple):
            batch, labels = batch
        else:
            labels = None

        if self.predict_with_generate:
            if isinstance(batch, dict):
                generation_inputs = batch[main_input_name]
                attention_mask = batch.get("attention_mask", None)
            else:
                generation_inputs = batch
                attention_mask = None
            if self.use_xla_generation:
                predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
            else:
                predictions = self.model.generate(
                    generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
                )
        else:
            predictions = self.model.predict_on_batch(batch)
            if isinstance(predictions, dict):
                # This converts any dict-subclass to a regular dict
                # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
                predictions = dict(predictions)
                if self.output_cols is not None:
                    predictions = {key: predictions[key] for key in self.output_cols}
                else:
                    predictions = {key: val for key, val in predictions.items() if key not in excluded_keys}
        prediction_list.append(predictions)

        if not self.use_keras_label:
            # Labels live inside the input batch under the configured label columns
            labels = {key: batch[key].numpy() for key in self.label_cols}
        elif isinstance(labels, dict):
            labels = {key: array.numpy() for key, array in labels.items()}
        elif isinstance(labels, (list, tuple)):
            labels = [array.numpy() for array in labels]
        elif isinstance(labels, tf.Tensor):
            labels = labels.numpy()
        else:
            raise TypeError(f"Confused by labels of type {type(labels)}")
        label_list.append(labels)

    all_preds = self._postprocess_predictions_or_labels(prediction_list)
    all_labels = self._postprocess_predictions_or_labels(label_list)

    metric_output = self.metric_fn((all_preds, all_labels))
    if not isinstance(metric_output, dict):
        raise TypeError(
            f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
        )

    # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
    # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
    # new keys in there, which will then get read by the History callback and treated like any other metric value.
    # I promise that I have it in writing from Chollet that this is okay.
    # The signature allows logs=None, in which case there is no dict to write into.
    if logs is not None:
        logs.update(metric_output)