def __init__()

in vihds/training.py [0:0]


    def __init__(self, args, settings: Config, data, parameters, model):
        """Set up a training routine: optimizer, LR scheduler, device-resident
        datasets, the mini-batch loader, and (optionally) TensorBoard paths."""
        # Keep references to the configuration and inputs
        self.args = args
        self.settings = settings
        self.dataset_pair = data
        self.model = model

        # ADAM over all model parameters, stepped down by a multi-step schedule
        self.optimizer = torch.optim.Adam(model.parameters(recurse=True), lr=settings.params.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, settings.params.learning_boundaries, gamma=settings.params.learning_gamma,
        )

        # Record the total parameter count (local + global) on the model
        counts = LocalAndGlobal.from_list(parameters.get_parameter_counts())
        self.model.n_theta = counts.sum()

        # Mini-batch size, capped by the number of training instances
        self.n_batch = min(settings.params.n_batch, data.n_train)

        # Full training and validation splits moved to the target device,
        # used for whole-split evaluation passes.
        train_set = data.train.dataset
        test_set = data.test.dataset
        self.train_data = batch_to_device(train_set.times, settings.device, train_set[data.train.indices])
        self.valid_data = batch_to_device(test_set.times, settings.device, test_set[data.test.indices])

        # Shuffled mini-batch loader over the training split
        self.train_loader = DataLoader(
            dataset=data.train,
            batch_size=self.n_batch,
            shuffle=True,
            collate_fn=functools.partial(collate_merged, train_set.times, settings.device),
        )

        # With no trainer configured, summaries are disabled; otherwise create
        # per-split TensorBoard output directories named after the held-out set.
        if settings.trainer is None:
            self.train_path = None
            self.valid_path = None
        else:
            heldout_label = args.heldout or "%d_of_%d" % (args.split, args.folds)
            self.train_path = os.path.join(self.settings.trainer.tb_log_dir, "train_%s" % heldout_label)
            self.valid_path = os.path.join(self.settings.trainer.tb_log_dir, "valid_%s" % heldout_label)
            os.makedirs(self.train_path, exist_ok=True)
            os.makedirs(self.valid_path, exist_ok=True)
        self.empty_cache = True