def pre_train(train_config)

in src/train.py


def pre_train(train_config):
    """Pre-train each model for its configured number of epochs before the main training loop."""
    c_file = train_config.config_file

    # nothing to do if no pre-training epochs are configured
    if c_file.epochsPretrain is None or len(c_file.epochsPretrain) == 0:
        return

    decay_rate = c_file.lrate_decay
    decay_steps = c_file.lrate_decay_steps

    batch_images = c_file.batchImages
    samples = c_file.samples

    # a value of -1 means "not set": fall back to the regular training configuration
    if c_file.batchImagesPretrain != -1:
        batch_images = c_file.batchImagesPretrain
    if c_file.samplesPretrain != -1:
        samples = c_file.samplesPretrain

    # pre-training may use a different sample count than regular training
    train_config.train_dataset.num_samples = samples

    # pretrain data_loader can have different batch size
    data_loader = train_config.train_data_loader
    if train_config.pretrain_data_loader is not None:
        data_loader = train_config.pretrain_data_loader

    # models are pre-trained one after the other
    for model_idx in range(len(train_config.models)):
        epoch_pretrain = c_file.epochsPretrain[model_idx]
        # skip models whose pre-training epochs have already been completed
        if train_config.epoch0 >= epoch_pretrain:
            continue

        # best validation loss seen so far for this model
        best_val_loss = 1.
        if model_idx < len(train_config.best_valid_loss_pretrain):
            best_val_loss = train_config.best_valid_loss_pretrain[model_idx]

        model = train_config.models[model_idx]
        optim = train_config.optimizers[model_idx]
        criterion = train_config.losses[model_idx]
        f_in = train_config.f_in[model_idx]

        batch_iterator = iter(data_loader)
        # number of batches in one full pass over the training set
        num_batches = int(math.ceil(len(train_config.train_dataset) / batch_images))

        tqdm_range = tqdm(range(train_config.epoch0, epoch_pretrain + 1), desc=f"pre-training model {model_idx}")

        for epoch in tqdm_range:
            optim.zero_grad()

            # get sample input data
            batch_samples = next(batch_iterator)
            sample_data = create_sample_wrapper(batch_samples, train_config)

            # we create a new iterator once all elements have been exhausted
            # we get a slight performance bump by creating the iterator here rather than at the beginning of the next loop iteration
            if epoch % num_batches == 0:
                batch_iterator = iter(data_loader)

            # during pre-training, the ground-truth targets of all previous models stand in for their outputs
            prev_outs = []
            for prev_model_idx in range(model_idx):
                prev_outs.append(sample_data.get_train_target(prev_model_idx))

            # build the input feature batch for this model and fetch the matching ground-truth targets
            inference_dict = f_in.batch(sample_data.get_batch_input(model_idx), prev_outs=prev_outs)
            x_batch = inference_dict[FeatureSetKeyConstants.input_feature_batch]
            y_batch = sample_data.get_train_target(model_idx)
            # flatten the per-image sample dimension into the batch dimension
            y_batch = y_batch.reshape(y_batch.shape[0] * samples, -1)
            out = model(x_batch)

            loss = criterion(out, y_batch, inference_dict=inference_dict)
            loss.backward()
            optim.step()

            # Learning rate decay
            new_lrate = c_file.lrate * (decay_rate ** ((train_config.epoch0 + epoch) / decay_steps))
            for param_group in optim.param_groups:
                param_group['lr'] = new_lrate

            if not c_file.nonVerbose:
                tqdm_range.set_description(f"epoch={epoch:<10} loss={loss:.8f}")

            # periodic checkpoint
            if epoch > 0 and epoch % c_file.epochsCheckpoint == 0:
                train_config.save_weights(name_suffix=f"{epoch:07d}")

            # debug outputs: render a full validation image
            if epoch % c_file.epochsRender == 0 and epoch > 0:
                val_data_set, _ = train_config.get_data_set_and_loader('val')
                old_full_images = val_data_set.full_images
                val_data_set.full_images = True
                img_samples = create_sample_wrapper(val_data_set[0], train_config, True)
                model_idxs = None if c_file.preTrained else [model_idx]
                render_img(train_config, img_samples, img_name=f"{epoch:07d}", model_idxs=model_idxs)
                val_data_set.full_images = old_full_images

            # validation
            if epoch % c_file.epochsValidate == 0 and epoch > 0:
                # release allocated memory - not done every single epoch because of the performance impact
                x_batch, y_batch, out = None, None, None
                for optimizer in train_config.optimizers:
                    optimizer.zero_grad()

                with torch.cuda.device(train_config.device):
                    torch.cuda.empty_cache()

                val_loss = validate_batch(train_config, epoch, loss, model_idx)
                # keep the weights that achieve the best validation loss so far
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    with open(f"{train_config.logDir}opt.txt", 'w') as f:
                        f.write(f"Optimal validation loss {best_val_loss} at epoch {epoch}")

                    train_config.save_weights(name_suffix="_opt", model_idx=model_idx)

                del val_loss

        # reload the configured checkpoint and advance epoch0 past this model's pre-training
        train_config.load_specific_weights(c_file.checkPointName, model_idx)
        train_config.epoch0 = epoch_pretrain

    # restore normal sample count
    train_config.train_dataset.num_samples = c_file.samples
    print("pre-training finished")