in egg/core/trainers.py [0:0]
def train(self, n_epochs):
    """Run the epoch loop for ``range(self.start_epoch, n_epochs)``, driving callbacks.

    Call sequence, with every hook broadcast to each object in
    ``self.callbacks`` in list order:

    * ``on_train_begin(self)`` once, before the first epoch.
    * For each epoch (callbacks receive ``epoch + 1`` as the epoch number):

      - ``on_epoch_begin(n)``.
      - ``self.train_epoch()``, which returns ``(loss, interaction)``.
      - ``on_epoch_end(loss, interaction, n)``.
      - If ``self.validation_data`` is not None, ``self.validation_freq > 0``
        and ``n`` is a multiple of ``self.validation_freq``: call
        ``on_validation_begin(n)``, then ``self.eval()``, then
        ``on_validation_end(val_loss, val_interaction, n)``.
      - If ``self.should_stop`` is truthy: call ``on_early_stopping(loss,
        interaction, n, val_loss, val_interaction)`` and leave the loop.
        Both validation values are None on epochs that skipped validation.

    * ``on_train_end()`` once, after the loop. It is skipped if an exception
      propagates out of the loop.

    NOTE(review): ``should_stop`` is never assigned here. It is presumably
    flipped by a callback or another method; confirm against the class.

    Args:
        n_epochs: Exclusive upper bound of the epoch index range.
    """

    def notify(hook_name, *hook_args):
        # Dispatch one callback hook to every registered callback, in order.
        for listener in self.callbacks:
            getattr(listener, hook_name)(*hook_args)

    notify("on_train_begin", self)

    for epoch in range(self.start_epoch, n_epochs):
        epoch_number = epoch + 1

        notify("on_epoch_begin", epoch_number)
        train_loss, train_interaction = self.train_epoch()
        notify("on_epoch_end", train_loss, train_interaction, epoch_number)

        # Stay None unless validation actually runs this epoch. These values
        # are forwarded to on_early_stopping either way.
        validation_loss, validation_interaction = None, None
        validate_now = (
            self.validation_data is not None
            and self.validation_freq > 0
            and epoch_number % self.validation_freq == 0
        )
        if validate_now:
            notify("on_validation_begin", epoch_number)
            validation_loss, validation_interaction = self.eval()
            notify(
                "on_validation_end",
                validation_loss,
                validation_interaction,
                epoch_number,
            )

        if not self.should_stop:
            continue

        notify(
            "on_early_stopping",
            train_loss,
            train_interaction,
            epoch_number,
            validation_loss,
            validation_interaction,
        )
        break

    notify("on_train_end")