in egg/core/trainers.py [0:0]
def train_epoch(self):
    """Run one training epoch over ``self.train_data``.

    Performs forward/backward on every batch, stepping the optimizer every
    ``self.update_freq`` batches (gradient accumulation), with optional
    AMP (``self.scaler``) and optional gradient-norm clipping.

    Returns:
        tuple: ``(mean_loss, full_interaction)`` where ``mean_loss`` is the
        mean of the per-batch (unscaled) losses as a Python float and
        ``full_interaction`` aggregates all per-batch interactions.
    """
    mean_loss = 0
    n_batches = 0
    interactions = []
    self.game.train()

    self.optimizer.zero_grad()

    for batch_id, batch in enumerate(self.train_data):
        if not isinstance(batch, Batch):
            batch = Batch(*batch)
        batch = batch.to(self.device)

        # Mixed-precision forward pass when a GradScaler is configured.
        context = autocast() if self.scaler else nullcontext()
        with context:
            optimized_loss, interaction = self.game(*batch)

            # Snapshot the *unscaled* loss for reporting before the
            # accumulation rescaling below.
            loss_for_logging = optimized_loss.detach()

            if self.update_freq > 1:
                # throughout EGG, we minimize _mean_ loss, not sum
                # hence, we need to account for that when aggregating grads
                optimized_loss = optimized_loss / self.update_freq

        if self.scaler:
            self.scaler.scale(optimized_loss).backward()
        else:
            optimized_loss.backward()

        # Step only once per accumulation window of update_freq batches.
        # NOTE(review): trailing batches of an epoch that do not complete a
        # full window never trigger a step; their gradients are discarded by
        # the zero_grad() at the top of the next epoch.
        if batch_id % self.update_freq == self.update_freq - 1:
            if self.scaler:
                # Unscale before clipping so the clip threshold applies to
                # the true gradient magnitudes, not the AMP-scaled ones.
                self.scaler.unscale_(self.optimizer)

            if self.grad_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.game.parameters(), self.grad_norm
                )
            if self.scaler:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            self.optimizer.zero_grad()

        n_batches += 1
        # Fix: accumulate the unscaled loss. The original accumulated
        # optimized_loss after the division by update_freq, under-reporting
        # the epoch mean loss by that factor whenever update_freq > 1.
        mean_loss += loss_for_logging

        if (
            self.distributed_context.is_distributed
            and self.aggregate_interaction_logs
        ):
            interaction = Interaction.gather_distributed_interactions(interaction)

        interaction = interaction.to("cpu")

        for callback in self.callbacks:
            callback.on_batch_end(interaction, optimized_loss, batch_id)

        interactions.append(interaction)

    if self.optimizer_scheduler:
        self.optimizer_scheduler.step()

    mean_loss /= n_batches

    full_interaction = Interaction.from_iterable(interactions)
    return mean_loss.item(), full_interaction