in src/train.py [0:0]
import math
import os
import shutil

import torch
from tqdm import tqdm

# NOTE: project-local helpers such as create_sample_wrapper, render_img, render_video,
# render_all_imgs and validate_batch are assumed to be imported elsewhere in src/train.py.


def train(train_config):
    loss = None
    best_val_loss = 1 if train_config.best_valid_loss is None else train_config.best_valid_loss
    c_file = train_config.config_file
    tqdm_range = tqdm(range(train_config.epoch0, train_config.epochs))
    # creating the iterator starts the data loader workers (if any)
    batch_iterator = iter(train_config.train_data_loader)
    num_batches = int(math.ceil(len(train_config.train_dataset) / c_file.batchImages))
    decay_rate = train_config.config_file.lrate_decay
    decay_steps = train_config.config_file.lrate_decay_steps
    pre_train_epochs = 0
    if train_config.config_file.epochsPretrain is not None and len(train_config.config_file.epochsPretrain) != 0:
        pre_train_epochs = max(train_config.config_file.epochsPretrain)
    for epoch in tqdm_range:
        for optim in train_config.optimizers:
            optim.zero_grad()

        # get sample input data
        samples = next(batch_iterator)
        sample_data = create_sample_wrapper(samples, train_config)

        # recreate the iterator once all batches have been consumed;
        # doing it here instead of at the start of the next iteration gives a slight performance bump
        if epoch % num_batches == 0:
            batch_iterator = iter(train_config.train_data_loader)
        # inference
        outs, inference_dicts = train_config.inference(sample_data, gradient=True)

        # train net: accumulate gradients from every active loss term
        for out_idx, criterion in enumerate(train_config.losses):
            if criterion is None or train_config.loss_weights[out_idx] == 0 or \
                    train_config.weights_locked(epoch, out_idx):
                continue
            y_batch = sample_data.get_train_target(out_idx)
            y_batch = y_batch.reshape(y_batch.shape[0] * c_file.samples, -1)
            out = outs[out_idx]
            inference_dict = inference_dicts[out_idx]
            if len(y_batch.shape) == 3:
                y_batch = y_batch[:, 0]

            loss = criterion(out, y_batch, inference_dict=inference_dict) * train_config.loss_weights[out_idx]
            # keep the graph alive as long as further loss terms still need to backpropagate through it
            loss.backward(retain_graph=out_idx < len(outs) - 1)
        # step every optimizer whose weights are not locked this epoch
        for i, optim in enumerate(train_config.optimizers):
            if not train_config.weights_locked(epoch, i):
                optim.step()

            # learning rate decay (exponential schedule, offset by the pre-training epochs)
            new_lrate = train_config.config_file.lrate * (
                decay_rate ** ((epoch - pre_train_epochs) / decay_steps))
            for param_group in optim.param_groups:
                param_group['lr'] = new_lrate

        if not c_file.nonVerbose:
            # PSNR is derived from the last weighted loss term, assuming an MSE-style loss over values in [0, 1]
            tqdm_range.set_description(f"epoch={epoch:<10} loss={loss:.8f} "
                                       f"psnr={10 * torch.log10(1. / loss):.8f}")
        # debug outputs
        if epoch % c_file.epochsCheckpoint == 0 and epoch > 0:
            train_config.save_weights(name_suffix=f"{epoch:07d}")

        if epoch % c_file.epochsRender == 0 and epoch > 0:
            # render a full validation image for visual inspection
            val_data_set, _ = train_config.get_data_set_and_loader('val')
            old_full_images = val_data_set.full_images
            val_data_set.full_images = True
            img_samples = create_sample_wrapper(val_data_set[0], train_config, True)
            render_img(train_config, img_samples, img_name=f"{epoch:07d}")
            val_data_set.full_images = old_full_images

        rendered_video = False
        if epoch % c_file.epochsVideo == 0 and epoch > 0 and c_file.epochsVideo >= 0:
            render_video(train_config, vid_name=f"{epoch:07d}")
            rendered_video = True
        if epoch % c_file.epochsValidate == 0 and epoch > 0:
            # release allocated memory -- not done every single epoch because of the performance impact
            batch_features, x_batch, y_batch, out = None, None, None, None
            for optimizer in train_config.optimizers:
                optimizer.zero_grad()
            with torch.cuda.device(train_config.device):
                torch.cuda.empty_cache()

            val_loss = validate_batch(train_config, epoch, loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                with open(f"{train_config.logDir}opt.txt", 'w') as f:
                    f.write(f"Optimal validation loss {best_val_loss} at epoch {epoch}")
                train_config.save_weights(name_suffix="_opt")
                render_all_imgs(train_config, "val_opt/", dataset_name="val")
                if not rendered_video and c_file.epochsVideo >= 0:
                    render_video(train_config, vid_name="_opt")
                elif rendered_video:
                    # simply copy the existing video over, because one was already rendered this epoch
                    for net_idx in range(len(train_config.models)):
                        shutil.copy(os.path.join(c_file.logDir, f"{epoch:07d}_{net_idx}.mp4"),
                                    os.path.join(c_file.logDir, f"_opt_{net_idx}.mp4"))

            # release allocated memory
            del val_loss
            with torch.cuda.device(train_config.device):
                torch.cuda.empty_cache()
    del batch_iterator