in sagemaker-voice-classification/notebook/train.py [0:0]
def train(model, epoch, train_loader, device, optimizer, log_interval):
    """Run one training epoch, randomly oversampling every batch.

    Each ``(data, target)`` batch from ``train_loader`` is rebalanced with
    the module-level ``ros`` sampler before the forward pass. Loss and
    accuracy on the resampled batch are printed every ``log_interval``
    batches.

    Args:
        model: network to train; put into training mode here.
        epoch: epoch number, used only in the progress message.
        train_loader: iterable yielding ``(data, target)`` tensor pairs.
        device: device the resampled batch is moved to before the forward pass.
        optimizer: optimizer zeroed and stepped once per batch.
        log_interval: print progress when ``batch_idx % log_interval == 0``.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = _oversample_batch(data, target)
        data, target = data.to(device), target.to(device)
        # Per the original note the model returns batchSize x 1 x 10; take the
        # middle axis' only entry to get the batchSize x 10 layout nll_loss wants.
        output = model(data)[:, 0, :]
        # NOTE(review): nll_loss assumes the model emits log-probabilities
        # (e.g. a final log_softmax) - confirm against the model definition.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Only computed when logging: it forces a device->host copy.
            pred = output.argmax(dim=1)
            accuracy = accuracy_score(target.cpu().numpy(), pred.cpu().numpy())
            print(
                "Train Epoch: {}, Loss: {:.4f}, Accuracy: {:.4f}".format(
                    epoch, loss.item(), accuracy
                )
            )


def _oversample_batch(data, target):
    """Return one batch rebalanced by the module-level ``ros`` sampler.

    The batch is flattened to ``(batch, features)`` for the sampler, and a
    singleton axis is re-inserted at position -2 afterwards. The returned
    data is therefore always ``(batch', 1, features)``, where ``batch'`` is
    the size after resampling.

    Args:
        data: batch of inputs whose first axis is the batch axis.
        target: 1-D tensor of class labels for ``data``.

    Returns:
        ``(data, target)`` as CPU tensors.
    """
    # Flatten every axis except the batch axis. A bare squeeze() would also
    # drop the batch axis for a one-example batch and break fit_resample.
    features = data.reshape(data.shape[0], -1).cpu().numpy()
    labels = target.cpu().numpy()
    # imblearn samplers raise ValueError when y has fewer than two classes
    # (always true for a one-example batch). There is nothing to rebalance
    # then, so keep the batch as-is but with the same output layout.
    if np.unique(labels).size > 1:
        features, labels = ros.fit_resample(features, labels)
    data = torch.from_numpy(features).unsqueeze(-2)
    target = torch.tensor(labels)
    return data, target