def train(train_config)

in src/train.py [0:0]


import math
import os
import shutil

import torch
from tqdm import tqdm

# the project-local helpers used below (create_sample_wrapper, render_img,
# render_video, render_all_imgs, validate_batch) are assumed to be imported
# from their modules in this repo


def train(train_config):
    loss = None

    # the loss is treated as an MSE over values in [0, 1] (see the PSNR readout
    # below), so 1 is a safe initial upper bound
    best_val_loss = 1 if train_config.best_valid_loss is None else train_config.best_valid_loss
    c_file = train_config.config_file

    tqdm_range = tqdm(range(train_config.epoch0, train_config.epochs))

    # creating the iterator starts the workers (if any)
    batch_iterator = iter(train_config.train_data_loader)
    num_batches = math.ceil(len(train_config.train_dataset) / c_file.batchImages)

    decay_rate = train_config.config_file.lrate_decay
    decay_steps = train_config.config_file.lrate_decay_steps

    # end of pre-training; the learning-rate decay below is measured from this epoch
    pre_train_epochs = 0
    if train_config.config_file.epochsPretrain:
        pre_train_epochs = max(train_config.config_file.epochsPretrain)

    for epoch in tqdm_range:
        for optim in train_config.optimizers:
            optim.zero_grad()

        # get sample input data
        samples = next(batch_iterator)
        sample_data = create_sample_wrapper(samples, train_config)

        # recreate the iterator once every batch has been consumed; building it here,
        # rather than at the top of the next iteration, gives a slight performance bump
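        # note that each "epoch" of this loop consumes exactly one batch, so a
        # full pass over the training set takes num_batches epochs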
        if epoch % num_batches == 0:
            batch_iterator = iter(train_config.train_data_loader)

        # inference
        outs, inference_dicts = train_config.inference(sample_data, gradient=True)

        # backpropagate every active loss
        for out_idx, criterion in enumerate(train_config.losses):
            if criterion is None or train_config.loss_weights[out_idx] == 0 or \
                    train_config.weights_locked(epoch, out_idx):
                continue

            y_batch = sample_data.get_train_target(out_idx)
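            # fold the per-image sample dimension into the batch dimension:
            # shape (B, ...) -> (B * samples, -1)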
            y_batch = y_batch.reshape(y_batch.shape[0] * c_file.samples, -1)
            out = outs[out_idx]
            inference_dict = inference_dicts[out_idx]

            # some targets carry an extra dimension; keep only the first slice
            if len(y_batch.shape) == 3:
                y_batch = y_batch[:, 0]

            loss = criterion(out, y_batch, inference_dict=inference_dict) * train_config.loss_weights[out_idx]
            loss.backward(retain_graph=out_idx < len(outs) - 1)
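            # if a later criterion ends up skipped, retain_graph stays True on the
            # final backward and the graph is only freed on the next iteration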

        for i, optim in enumerate(train_config.optimizers):
            if not train_config.weights_locked(epoch, i):
                optim.step()

            # exponential learning-rate decay, measured from the end of pre-training;
            # for decay_rate < 1 the factor stays above 1 until epoch reaches pre_train_epochs
            new_lrate = train_config.config_file.lrate * (
                    decay_rate ** ((epoch - pre_train_epochs) / decay_steps))
            for param_group in optim.param_groups:
                param_group['lr'] = new_lrate
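            # note: the decayed rate is applied to every optimizer, including ones
            # whose weights are currently locked, so they resume at the scheduled rate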

        if not c_file.nonVerbose:
            # the PSNR readout assumes the loss is an MSE over values in [0, 1]
            tqdm_range.set_description(f"epoch={epoch:<10} loss={loss:.8f} "
                                       f"psnr={10 * torch.log10(1. / loss):.8f}")

        # debug outputs
        if epoch % c_file.epochsCheckpoint == 0 and epoch > 0:
            train_config.save_weights(name_suffix=f"{epoch:07d}")

        if epoch % c_file.epochsRender == 0 and epoch > 0:
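            # temporarily switch the validation set to full-image mode for rendering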
            val_data_set, _ = train_config.get_data_set_and_loader('val')
            old_full_images = val_data_set.full_images
            val_data_set.full_images = True
            img_samples = create_sample_wrapper(val_data_set[0], train_config, True)
            render_img(train_config, img_samples, img_name=f"{epoch:07d}")
            val_data_set.full_images = old_full_images

        # remember whether a video was rendered this epoch; the best-model branch
        # below reuses the file instead of re-rendering
        rendered_video = False
        if c_file.epochsVideo > 0 and epoch % c_file.epochsVideo == 0 and epoch > 0:
            render_video(train_config, vid_name=f"{epoch:07d}")
            rendered_video = True

        if epoch % c_file.epochsValidate == 0 and epoch > 0:
            # drop tensor references before validating; clearing the CUDA cache is
            # too expensive to do every single epoch
            outs, inference_dicts, y_batch, out = None, None, None, None
            for optimizer in train_config.optimizers:
                optimizer.zero_grad()

            with torch.cuda.device(train_config.device):
                torch.cuda.empty_cache()

            val_loss = validate_batch(train_config, epoch, loss)
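            # snapshot weights, validation renders and a video whenever validation improves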
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                with open(f"{train_config.logDir}opt.txt", 'w') as f:
                    f.write(f"Optimal validation loss {best_val_loss} at epoch {epoch}")

                train_config.save_weights(name_suffix="_opt")

                render_all_imgs(train_config, "val_opt/", dataset_name="val")

                if not rendered_video and c_file.epochsVideo >= 0:
                    render_video(train_config, vid_name="_opt")
                elif rendered_video:
                        # a video was rendered this epoch already, so copy it instead of re-rendering
                    for net_idx in range(len(train_config.models)):
                        shutil.copy(os.path.join(c_file.logDir, f"{epoch:07d}_{net_idx}.mp4"),
                                    os.path.join(c_file.logDir, f"_opt_{net_idx}.mp4"))

            # release allocated memory
            del val_loss

            with torch.cuda.device(train_config.device):
                torch.cuda.empty_cache()

    # dropping the iterator shuts down any data-loader worker processes
    del batch_iterator
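
For reference, the learning-rate schedule and the PSNR readout above reduce to two
one-line formulas. The sketch below is a minimal, self-contained illustration; the
helper names and the sample settings are illustrative, not part of the repo.

import math


def decayed_lr(base_lr, decay_rate, decay_steps, epoch, pre_train_epochs=0):
    # mirrors the schedule in train(): exponential decay measured from the end
    # of pre-training; for decay_rate < 1 the factor exceeds 1 until epoch
    # reaches pre_train_epochs
    return base_lr * decay_rate ** ((epoch - pre_train_epochs) / decay_steps)


def psnr_from_mse(mse):
    # PSNR for signals in [0, 1], matching the progress-bar readout in train()
    return 10 * math.log10(1.0 / mse)


# e.g. a base rate of 5e-4 decaying by 0.1 every 250000 steps:
print(decayed_lr(5e-4, 0.1, 250_000, epoch=125_000))  # ~1.58e-4
print(psnr_from_mse(1e-3))                            # 30.0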