in src/sm_augmentation_train-script.py [0:0]
def _unpack_batch(batch, device, use_pytorch):
    """Return an ``(inputs, labels)`` pair from one loader batch.

    With ``use_pytorch`` true the batch is an ``(inputs, labels)`` pair that
    is moved to ``device``. Otherwise the batch is indexed the way the DALI
    branch of the training loop expects: ``batch[0]["data"]`` and
    ``batch[0]["label"]``, with the trailing label axis squeezed and the
    labels cast to ``long``.
    """
    if use_pytorch:
        inputs, labels = batch
        return inputs.to(device), labels.to(device)
    sample = batch[0]
    # NOTE(review): the DALI tensors are not moved to `device` here —
    # presumably the pipeline already yields them on the target device.
    return sample["data"], sample["label"].squeeze(-1).long()


def run_training_epochs(model_ft, num_epochs, criterion, optimizer_ft, dataloaders, dataset_sizes, device, USE_PYTORCH):
    """Train and validate ``model_ft``, keeping the best validation weights.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase, prints
    the per-phase loss/accuracy and the epoch wall time, and finally prints
    the average seconds per epoch.

    Args:
        model_ft: Model to train; updated in place.
        num_epochs: Number of epochs to run.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer_ft: Optimizer stepped once per training batch.
        dataloaders: Mapping with ``'train'`` and ``'val'`` iterables.
        dataset_sizes: Mapping with the sample count for each phase, used to
            normalise the accumulated loss and correct-prediction count.
        device: Device the PyTorch-dataloader batches are moved to.
        USE_PYTORCH: True when ``dataloaders`` yield ``(inputs, labels)``
            pairs; False when they yield DALI-style batches.

    Returns:
        Tuple ``(model_ft, best_acc)`` where the model has been reloaded with
        the weights of the epoch that achieved the highest validation
        accuracy and ``best_acc`` is that accuracy as a float.
    """
    best_model_wts = copy.deepcopy(model_ft.state_dict())
    best_acc = 0.0
    total_epoch_time = 0.0

    for epoch in range(num_epochs):
        print('Running Epoch {}/{}'.format(epoch + 1, num_epochs))
        epoch_start_time = time.perf_counter()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model_ft.train()
            else:
                model_ft.eval()

            running_loss = 0.0
            running_corrects = 0

            for batch in dataloaders[phase]:
                inputs, labels = _unpack_batch(batch, device, USE_PYTORCH)

                if is_train:
                    optimizer_ft.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model_ft(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer_ft.step()

                # The criterion's value is weighted by the batch size so the
                # division by the dataset size below yields a per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int(torch.sum(preds == labels).item())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{}-loss: {:.4f} {}-acc: {:.4f}'.format(
                phase, epoch_loss, phase, epoch_acc))

            # Record the new best accuracy together with its weights; without
            # updating best_acc every epoch would overwrite the snapshot.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())

        epoch_time_elapsed = time.perf_counter() - epoch_start_time
        print('Epoch completed in {:.2f}s'.format(epoch_time_elapsed))
        total_epoch_time += epoch_time_elapsed

    # Calculating Seconds/Epoch: metric used for comparing performance across the experiments
    print('-' * 25)
    if num_epochs > 0:
        print('Seconds per Epoch: {:.2f}'.format(total_epoch_time / num_epochs))

    model_ft.load_state_dict(best_model_wts)
    return model_ft, best_acc