in src/train.py [0:0]
import math

import torch
from tqdm import tqdm


def pre_train(train_config):
    """Pre-train each model for the epoch budget given in config_file.epochsPretrain.

    Uses the pre-training specific sample count and batch size if configured, and
    restores the regular settings once all models have been pre-trained.
    """
    if train_config.config_file.epochsPretrain is None or len(train_config.config_file.epochsPretrain) == 0:
        return

    c_file = train_config.config_file
    decay_rate = c_file.lrate_decay
    decay_steps = c_file.lrate_decay_steps
    batch_images = c_file.batchImages
    samples = c_file.samples

    if c_file.batchImagesPretrain != -1:
        batch_images = c_file.batchImagesPretrain
    if c_file.samplesPretrain != -1:
        samples = c_file.samplesPretrain

    train_config.train_dataset.num_samples = samples

    # the pre-training data loader can have a different batch size
    data_loader = train_config.train_data_loader
    if train_config.pretrain_data_loader is not None:
        data_loader = train_config.pretrain_data_loader
    for model_idx in range(len(train_config.models)):
        epoch_pretrain = train_config.config_file.epochsPretrain[model_idx]

        if train_config.epoch0 >= epoch_pretrain:
            continue

        best_val_loss = 1.
        if model_idx < len(train_config.best_valid_loss_pretrain):
            best_val_loss = train_config.best_valid_loss_pretrain[model_idx]

        model = train_config.models[model_idx]
        optim = train_config.optimizers[model_idx]
        criterion = train_config.losses[model_idx]
        f_in = train_config.f_in[model_idx]

        batch_iterator = iter(data_loader)
        num_batches = int(math.ceil(len(train_config.train_dataset) / batch_images))

        tqdm_range = tqdm(range(train_config.epoch0, epoch_pretrain + 1), desc=f"pre-training model {model_idx}")
        for epoch in tqdm_range:
            optim.zero_grad()

            # get sample input data
            batch_samples = next(batch_iterator)
            sample_data = create_sample_wrapper(batch_samples, train_config)
            # create a new iterator once all elements have been exhausted; doing it here
            # rather than at the beginning of the next iteration gives a slight performance bump
            if epoch % num_batches == 0:
                batch_iterator = iter(data_loader)
            # feed the ground-truth targets of all previously trained models as additional inputs
            prev_outs = []
            for prev_model_idx in range(model_idx):
                prev_outs.append(sample_data.get_train_target(prev_model_idx))

            inference_dict = f_in.batch(sample_data.get_batch_input(model_idx), prev_outs=prev_outs)
            x_batch = inference_dict[FeatureSetKeyConstants.input_feature_batch]

            y_batch = sample_data.get_train_target(model_idx)
            y_batch = y_batch.reshape(y_batch.shape[0] * samples, -1)

            out = model(x_batch)
            loss = criterion(out, y_batch, inference_dict=inference_dict)
            loss.backward()
            optim.step()
            # Learning rate decay
            new_lrate = c_file.lrate * (decay_rate ** ((train_config.epoch0 + epoch) / decay_steps))
            for param_group in optim.param_groups:
                param_group['lr'] = new_lrate

            if not c_file.nonVerbose:
                tqdm_range.set_description(f"epoch={epoch:<10} loss={loss:.8f}")

            if epoch > 0 and epoch % train_config.config_file.epochsCheckpoint == 0:
                train_config.save_weights(name_suffix=f"{epoch:07d}")
            # debug output: render a full validation image
            if epoch % c_file.epochsRender == 0 and epoch > 0:
                val_data_set, _ = train_config.get_data_set_and_loader('val')

                old_full_images = val_data_set.full_images
                val_data_set.full_images = True

                img_samples = create_sample_wrapper(val_data_set[0], train_config, True)
                model_idxs = None if train_config.config_file.preTrained else [model_idx]
                render_img(train_config, img_samples, img_name=f"{epoch:07d}", model_idxs=model_idxs)

                val_data_set.full_images = old_full_images
            # debug output: run validation and track the best model
            if epoch % c_file.epochsValidate == 0 and epoch > 0:
                # release allocated memory; not done every epoch because of the performance impact
                x_batch, y_batch, out = None, None, None
                for optimizer in train_config.optimizers:
                    optimizer.zero_grad()
                with torch.cuda.device(train_config.device):
                    torch.cuda.empty_cache()

                val_loss = validate_batch(train_config, epoch, loss, model_idx)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    with open(f"{train_config.logDir}opt.txt", 'w') as f:
                        f.write(f"Optimal validation loss {best_val_loss} at epoch {epoch}")
                    train_config.save_weights(name_suffix="_opt", model_idx=model_idx)

                del val_loss
        train_config.load_specific_weights(c_file.checkPointName, model_idx)
        train_config.epoch0 = epoch_pretrain

    # restore normal sample count
    train_config.train_dataset.num_samples = c_file.samples
    print("pre-training finished")