# python/dpu_utils/ptutils/componenttrainer.py
# The module-level imports and type variables below are reconstructed from their
# usage in this excerpt; the surrounding ComponentTrainer class body is elided.
import json
import math
import time
from typing import Callable, Dict, Iterable, Iterator, Optional, Set, Tuple, TypeVar, Union

import torch
from tqdm import tqdm

from dpu_utils.utils import ThreadedIterator, shuffled_iterator

InputData = TypeVar('InputData')
TensorizedData = TypeVar('TensorizedData')

def train(self, training_data: Iterable[InputData], validation_data: Iterable[InputData],
          show_progress_bar: bool = True, patience: int = 5, initialize_metadata: bool = True,
          exponential_running_average_factor: float = 0.97,
          get_parameters_to_freeze: Optional[Callable[[], Set]] = None,
          parallel_minibatch_creation: bool = False,
          device: Optional[Union[str, torch.device]] = None) -> None:
"""
The training-validation loop for `BaseComponent`s.
:param training_data: An iterable that each iteration yields the full training data.
:param validation_data: An iterable that each iteration yields the full validation data.
:param show_progress_bar: Show a progress bar
:param patience: The number of iterations before early stopping kicks in.
:param initialize_metadata: If true, initialize the metadata from the training_data. Otherwise,
assume that the model that is being trained has its metadata already initialized.
:param exponential_running_average_factor: The factor of the running average of the training loss
displayed in the progress bar.
:param get_parameters_to_freeze: The (optional) callable that returns the set of parameters to freeze during training.
:param parallel_minibatch_creation: If True the minibatches will be created in a separate thread.
"""
    if initialize_metadata:
        self.__load_metadata(training_data)

    self.LOGGER.info('Model has %s parameters', self.__model.num_parameters())
    self.LOGGER.debug('Data Tensorization Started...')
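    # Tensorization helpers: convert raw samples to tensorized form on a background
    # thread so that data preparation overlaps with model computation.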
    def data_to_tensor_iterator(data):
        for datapoint in data:
            tensorized_datapoint = self.__model.load_data_from_sample(datapoint)
            if tensorized_datapoint is not None:
                yield tensorized_datapoint
    def training_tensors():
        yield from ThreadedIterator(
            original_iterator=data_to_tensor_iterator(training_data),
            max_queue_size=10 * self.__minibatch_size
        )

    def validation_tensors():
        yield from ThreadedIterator(
            original_iterator=data_to_tensor_iterator(validation_data),
            max_queue_size=10 * self.__minibatch_size
        )
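    # Group tensorized samples into minibatches. A trailing partial minibatch is
    # yielded only when `return_partial_minibatches` is set (validation uses this
    # so that every example is scored; training drops the partial batch).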
    def minibatch_iterator(data_iterator: Iterator[TensorizedData],
                           return_partial_minibatches: bool = False) -> Iterator[Tuple[Dict, int]]:
        while True:
            mb_data, batch_is_full, num_elements = self.__model.create_minibatch(
                data_iterator, max_num_items=self.__minibatch_size)
            if num_elements == 0:
                break
            elif not batch_is_full and not return_partial_minibatches:
                break  # Do not return partial minibatches when the iterator is exhausted.
            else:
                yield mb_data, num_elements
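    # Resolve the target device: prefer the first CUDA GPU when one is available.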
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    self.__model.to(device)
    self.LOGGER.info('Using %s for training.', device)
    if get_parameters_to_freeze is None:
        get_parameters_to_freeze = lambda: set()
    trainable_parameters = set(self.__model.parameters()) - get_parameters_to_freeze()
    optimizer = self.__create_optimizer(trainable_parameters)
    scheduler = None if self.__create_scheduler is None else self.__create_scheduler(optimizer)

    for hook in self.__training_start_hooks:
        hook(self.__model, optimizer)
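    # Train/validate loop with early stopping: track the best validation loss seen
    # so far and the number of consecutive epochs without improvement.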
    best_loss = float('inf')  # type: float
    num_epochs_not_improved = 0  # type: int
    for epoch in range(self.__max_num_epochs):
        self.__model.train()

        data_iter = shuffled_iterator(training_tensors())
        sum_epoch_loss = 0.0
        running_avg_loss = 0.0
        num_minibatches = 0
        num_samples = 0
        start_time = time.time()
        self.__model.reset_metrics()
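        # One optimizer (and optional scheduler) step per full minibatch; the progress
        # bar shows an exponential running average of the training loss.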
        with tqdm(desc='Training', disable=not show_progress_bar, leave=False) as progress_bar:
            for step_idx, (mb_data, num_elements) in enumerate(ThreadedIterator(
                    minibatch_iterator(data_iter, return_partial_minibatches=False),
                    enabled=parallel_minibatch_creation)):
                optimizer.zero_grad()
                mb_loss = self.__model(**mb_data)
                mb_loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step(epoch_idx=epoch, epoch_step=step_idx)

                loss = float(mb_loss.cpu())
                if math.isnan(loss):
                    raise Exception('Training Loss has a NaN value.')

                sum_epoch_loss += loss
                num_minibatches += 1
                num_samples += num_elements

                if num_minibatches == 1:  # First minibatch
                    running_avg_loss = loss
                else:
                    running_avg_loss = (exponential_running_average_factor * running_avg_loss
                                        + (1 - exponential_running_average_factor) * loss)
                progress_bar.update()
                progress_bar.set_postfix(Loss=f'{running_avg_loss:.2f}')
        elapsed_time = time.time() - start_time  # type: float
        self.LOGGER.info('Training complete in %.1fsec [%.2f samples/sec]', elapsed_time,
                         num_samples / elapsed_time)
        assert num_minibatches > 0, ('No training minibatches were created. The minibatch size may '
                                     'be too large or the training dataset size too small.')
        self.LOGGER.info('Epoch %i: Avg Train Loss %.2f', epoch + 1, sum_epoch_loss / num_minibatches)
        train_metrics = self.__model.report_metrics()
        for epoch_hook in self.__train_epoch_end_hooks:
            epoch_hook(self.__model, epoch, train_metrics)
        if len(train_metrics) > 0:
            self.LOGGER.info('Training Metrics: %s', json.dumps(train_metrics, indent=2))
        # Now do validation!
        self.__model.eval()
        data_iter = validation_tensors()
        sum_epoch_loss = 0.0
        num_minibatches = 0
        num_samples = 0
        start_time = time.time()
        self.__model.reset_metrics()
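        # Validation also consumes the trailing partial minibatch so that every
        # example contributes to the validation loss.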
        with tqdm(desc='Validation', disable=not show_progress_bar, leave=False) as progress_bar, torch.no_grad():
            for mb_data, num_elements in ThreadedIterator(
                    minibatch_iterator(data_iter, return_partial_minibatches=True),
                    enabled=parallel_minibatch_creation):
                mb_loss = self.__model(**mb_data)

                loss = float(mb_loss.cpu())
                if math.isnan(loss):
                    raise Exception('Validation Loss has a NaN value.')

                sum_epoch_loss += loss
                num_minibatches += 1
                num_samples += num_elements
                progress_bar.update()
                progress_bar.set_postfix(Loss=f'{sum_epoch_loss / num_minibatches:.2f}')
        elapsed_time = time.time() - start_time
        assert num_samples > 0, 'No validation data was found.'
        validation_loss = sum_epoch_loss / num_minibatches
        self.LOGGER.info('Validation complete in %.1fsec [%.2f samples/sec]', elapsed_time,
                         num_samples / elapsed_time)
        self.LOGGER.info('Epoch %i: Avg Valid Loss %.2f', epoch + 1, validation_loss)
        validation_metrics = self.__model.report_metrics()
        for epoch_hook in self.__validation_epoch_end_hooks:
            epoch_hook(self.__model, epoch, validation_metrics)
        if len(validation_metrics) > 0:
            self.LOGGER.info('Validation Metrics: %s', json.dumps(validation_metrics, indent=2))
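        # Early stopping: checkpoint whenever the validation loss improves; stop once
        # it has failed to improve for more than `patience` consecutive epochs.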
        if validation_loss < best_loss:
            self.LOGGER.info('Best loss so far --- Saving model.')
            num_epochs_not_improved = 0
            self.__save_current_model()
            best_loss = validation_loss
        else:
            num_epochs_not_improved += 1
            if num_epochs_not_improved > patience:
                self.LOGGER.warning('Validation loss has not improved for %s epochs. Stopping.',
                                    num_epochs_not_improved)
                break
    # Restore the best model that was found.
    self.restore_model()
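
# --- Usage sketch (illustrative; not part of the original file) --------------
# Shows how `train` might be invoked. `MyComponent` (a BaseComponent subclass),
# `load_samples`, and the exact `ComponentTrainer` constructor arguments are
# assumptions for illustration; only the `train(...)` keyword arguments come
# from the signature above. Note that both data arguments must be re-iterable,
# since the loop traverses them once per epoch.
#
#     from dpu_utils.utils import RichPath
#
#     model = MyComponent('demo-model')                    # hypothetical component
#     trainer = ComponentTrainer(model, RichPath.create('./model.pkl.gz'))  # assumed ctor
#     train_samples = load_samples('train.jsonl.gz')       # hypothetical, re-iterable
#     valid_samples = load_samples('valid.jsonl.gz')
#     trainer.train(train_samples, valid_samples,
#                   patience=10,
#                   parallel_minibatch_creation=True)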