engine/training_engine.py
def train_epoch(self, epoch):
time.sleep(2) # To prevent possible deadlock during epoch transition
train_stats = Statistics(metric_names=self.metric_names, is_master_node=self.is_master_node)
self.model.train()
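# gradient accumulation kicks in only after accum_after_epoch epochs; before that, step on every batch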
accum_freq = self.accum_freq if epoch > self.accum_after_epoch else 1
max_norm = getattr(self.opts, "common.grad_clip", None)
self.optimizer.zero_grad()
epoch_start_time = time.time()
batch_load_start = time.time()
for batch_id, batch in enumerate(self.train_loader):
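# stop once the global iteration budget is exhausted (iteration-driven training)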
if self.train_iterations > self.max_iterations:
self.max_iterations_reached = True
return -1, -1
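# time spent waiting on the data loader for this batch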
batch_load_toc = time.time() - batch_load_start
input_img, target_label = batch['image'], batch['label']
# move data to device
input_img = input_img.to(self.device)
if isinstance(target_label, Dict):
for k, v in target_label.items():
target_label[k] = v.to(self.device)
else:
target_label = target_label.to(self.device)
batch_size = input_img.shape[0]
# update the learning rate
self.optimizer = self.scheduler.update_lr(optimizer=self.optimizer, epoch=epoch,
curr_iter=self.train_iterations)
# adjust the momentum of normalization layers (e.g., BatchNorm), if enabled
if self.adjust_norm_mom is not None:
self.adjust_norm_mom.adjust_momentum(model=self.model,
epoch=epoch,
iteration=self.train_iterations)
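# forward pass and loss computation run under autocast when mixed-precision training is enabled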
with autocast(enabled=self.mixed_precision_training):
# prediction
pred_label = self.model(input_img)
# compute loss
loss = self.criteria(input_sample=input_img, prediction=pred_label, target=target_label)
if isinstance(loss, torch.Tensor) and torch.isnan(loss):
    # fail fast on a NaN loss instead of dropping into an interactive debugger
    raise RuntimeError("NaN loss encountered at iteration {}".format(self.train_iterations))
# perform the backward pass with gradient accumulation [Optional]
self.gradient_scalar.scale(loss).backward()
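# step the optimizer only every accum_freq batches; gradients accumulate in between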
if (batch_id + 1) % accum_freq == 0:
if max_norm is not None:
# For gradient clipping, unscale the gradients and then clip them
self.gradient_scalar.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=max_norm)
# optimizer step
self.gradient_scalar.step(optimizer=self.optimizer)
# update the scale for next batch
self.gradient_scalar.update()
# set grads to zero
self.optimizer.zero_grad()
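# update the exponential moving average (EMA) of the model weights, if enabled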
if self.model_ema is not None:
self.model_ema.update_parameters(self.model)
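# compute batch metrics (optionally reduced across processes in distributed mode) and update running statistics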
metrics = metric_monitor(pred_label=pred_label, target_label=target_label, loss=loss,
use_distributed=self.use_distributed, metric_names=self.metric_names)
train_stats.update(metric_vals=metrics, batch_time=batch_load_toc, n=batch_size)
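# periodically log an iteration summary from the master node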
if batch_id % self.log_freq == 0 and self.is_master_node:
lr = self.scheduler.retrieve_lr(self.optimizer)
train_stats.iter_summary(epoch=epoch,
n_processed_samples=self.train_iterations,
total_samples=self.max_iterations,
learning_rate=lr,
elapsed_time=epoch_start_time)
batch_load_start = time.time()
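# global iteration counter used by the LR scheduler and the stopping criterion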
self.train_iterations += 1
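# end of epoch: return the average loss and the metric used for checkpoint selection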
avg_loss = train_stats.avg_statistics(metric_name='loss')
train_stats.epoch_summary(epoch=epoch, stage="training")
avg_ckpt_metric = train_stats.avg_statistics(metric_name=self.ckpt_metric)
return avg_loss, avg_ckpt_metric