in vihds/training.py [0:0]
def __init__(self, args, settings: Config, data, parameters, model):
    """Set up optimisation state, data, and log directories for a training run.

    Args:
        args: run arguments. ``heldout``, ``split`` and ``folds`` are read only when
            ``settings.trainer`` is set; they are used to name the log directories.
        settings: run configuration.
            - ``params`` supplies the learning-rate, scheduler and batch-size
              hyperparameters.
            - ``device`` is passed to ``batch_to_device`` / ``collate_merged``.
            - ``trainer`` (may be None) supplies ``tb_log_dir``.
        data: the train/test dataset pair.
            - ``data.train`` and ``data.test`` expose ``.dataset`` and ``.indices``
              (presumably ``torch.utils.data.Subset``; TODO confirm).
            - ``data.n_train`` caps the batch size.
        parameters: object whose ``get_parameter_counts()`` result is summed into
            ``model.n_theta``.
        model: the torch module whose parameters are optimised. Its ``n_theta``
            attribute is overwritten here.
    """
    # Keep references to the inputs for the other methods of this object.
    self.args, self.settings, self.dataset_pair, self.model = args, settings, data, model
    hyper = settings.params

    # Adam over all (recursively collected) model parameters. The learning rate is
    # multiplied by ``learning_gamma`` at each milestone in ``learning_boundaries``.
    self.optimizer = torch.optim.Adam(model.parameters(recurse=True), lr=hyper.learning_rate)
    self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
        self.optimizer, hyper.learning_boundaries, gamma=hyper.learning_gamma
    )

    # Store the total parameter count, as summed by LocalAndGlobal, on the model.
    counts = LocalAndGlobal.from_list(parameters.get_parameter_counts())
    self.model.n_theta = counts.sum()

    # Batch size, capped at the number of training instances.
    self.n_batch = min(hyper.n_batch, data.n_train)

    device = settings.device
    train_split, test_split = data.train, data.test

    # Complete training and validation sets, prepared once via batch_to_device so that
    # quantification can use all of the data rather than a single mini-batch.
    self.train_data = batch_to_device(
        train_split.dataset.times, device, train_split.dataset[train_split.indices]
    )
    self.valid_data = batch_to_device(
        test_split.dataset.times, device, test_split.dataset[test_split.indices]
    )

    # Shuffled mini-batch loader over the training split.
    self.train_loader = DataLoader(
        dataset=train_split,
        batch_size=self.n_batch,
        shuffle=True,
        collate_fn=functools.partial(collate_merged, train_split.dataset.times, device),
    )

    # Result / tensorboard directories exist only when a trainer section is configured.
    if settings.trainer is None:
        self.train_path = None
        self.valid_path = None
    else:
        tag = args.heldout or "%d_of_%d" % (args.split, args.folds)
        log_root = settings.trainer.tb_log_dir
        self.train_path = os.path.join(log_root, "train_%s" % tag)
        self.valid_path = os.path.join(log_root, "valid_%s" % tag)
        for path in (self.train_path, self.valid_path):
            os.makedirs(path, exist_ok=True)

    # NOTE(review): this flag is consumed outside this view; it is initialised to True.
    self.empty_cache = True