# run_training_epochs — training-loop excerpt from src/sm_augmentation_train-script.py

def _process_batch(model_ft, criterion, optimizer_ft, inputs, labels, is_training):
    """Run one forward (and, when training, backward) pass on a single batch.

    Returns a tuple ``(loss_sum, corrects)``: the batch loss scaled by the
    batch size and the number of correct predictions, both as Python numbers
    so callers can accumulate them without holding tensors.
    """
    optimizer_ft.zero_grad()
    # Gradients are only tracked during the training phase.
    with torch.set_grad_enabled(is_training):
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        if is_training:
            loss.backward()
            optimizer_ft.step()
    return loss.item() * inputs.size(0), torch.sum(preds == labels.data).item()


def run_training_epochs(model_ft, num_epochs, criterion, optimizer_ft, dataloaders, dataset_sizes, device, USE_PYTORCH):
    """Train ``model_ft`` for ``num_epochs`` epochs with a train/val phase each.

    Args:
        model_ft: the model to train (moved to ``device`` by the caller).
        num_epochs: number of epochs to run.
        criterion: loss function.
        optimizer_ft: optimizer updating ``model_ft``'s parameters.
        dataloaders: dict with ``'train'`` and ``'val'`` data loaders.
        dataset_sizes: dict with sample counts for ``'train'`` and ``'val'``.
        device: target device for PyTorch-loaded batches.
        USE_PYTORCH: True for PyTorch DataLoaders, False for DALI pipelines.

    Returns:
        ``(model_ft, best_acc)`` — the model with the best-validation-accuracy
        weights loaded, and that best accuracy as a float.
    """
    best_model_wts = copy.deepcopy(model_ft.state_dict())
    best_acc = 0.0

    total_epoch_time = 0
    for epoch in range(num_epochs):
        print('Running Epoch {}/{}'.format(epoch + 1, num_epochs))

        epoch_start_time = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:

            if phase == 'train':
                model_ft.train()
            else:
                model_ft.eval()

            running_loss = 0.0
            running_corrects = 0
            is_training = phase == 'train'

            # Data iteration if using DALI Pipelines for loading the augmented data
            if not USE_PYTORCH:
                # DALI batches arrive as dicts and are already on the device.
                for data in dataloaders[phase]:
                    inputs = data[0]["data"]
                    labels = data[0]["label"].squeeze(-1).long()

                    batch_loss, batch_corrects = _process_batch(
                        model_ft, criterion, optimizer_ft, inputs, labels, is_training)
                    running_loss += batch_loss
                    running_corrects += batch_corrects

            # Data iteration if using PyTorch Dataloader for loading the augmented data
            else:
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    batch_loss, batch_corrects = _process_batch(
                        model_ft, criterion, optimizer_ft, inputs, labels, is_training)
                    running_loss += batch_loss
                    running_corrects += batch_corrects

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{}-loss: {:.4f} {}-acc: {:.4f}'.format(
                phase, epoch_loss, phase, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                # BUG FIX: best_acc was never updated, so the function always
                # returned 0.0 and re-snapshotted weights every epoch with a
                # non-zero validation accuracy.
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())

        epoch_time_elapsed = time.time() - epoch_start_time
        print('Epoch completed in {:.2f}s'.format(epoch_time_elapsed))
        total_epoch_time = total_epoch_time + epoch_time_elapsed

    # Calculating Seconds/Epoch: metric used for comparing performance for the experiments
    print('-' * 25)
    if num_epochs > 0:  # guard against division by zero for an empty run
        print('Seconds per Epoch: {:.2f}'.format(total_epoch_time / num_epochs))

    # Restore the weights from the best-validation-accuracy epoch.
    model_ft.load_state_dict(best_model_wts)
    return model_ft, best_acc