def train()

in sagemaker-voice-classification/notebook/train.py [0:0]


def train(model, epoch, train_loader, device, optimizer, log_interval):
    """Train ``model`` for one pass over ``train_loader`` with per-batch oversampling.

    Each batch is flattened to one feature row per sample, class-balanced with
    the module-level sampler ``ros``, restored to a three-dimensional tensor
    and used for a single optimizer step on the negative log-likelihood loss.
    Every ``log_interval`` batches the loss and the accuracy on the
    (possibly resampled) batch are printed.

    Args:
        model: Network to train. Its output is indexed as
            ``(batch, 1, classes)`` and passed to ``F.nll_loss``, so it is
            presumably log-probabilities -- confirm against the model.
        epoch: Epoch number; used only in the log line.
        train_loader: Iterable yielding ``(data, target)`` batches.
        device: Device the batch is moved to before the forward pass.
        optimizer: Optimizer that updates the parameters of ``model``.
        log_interval: A log line is printed every this many batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        # The sampler works on a 2-D (samples, features) array. Reshaping on
        # the batch axis, instead of squeezing every singleton axis, keeps a
        # batch that contains a single sample two-dimensional.
        # NOTE(review): assumes batches arrive as (batch, 1, length) -- confirm.
        features = np.asarray(data)
        features = features.reshape(features.shape[0], -1)
        labels = np.asarray(target)

        ## oversampling
        # There is nothing to balance when the batch holds a single class
        # (imbalanced-learn samplers raise on such a target), so a batch like
        # that -- typically a short final batch -- is trained on unchanged.
        if np.unique(labels).size > 1:
            features, labels = ros.fit_resample(features, labels)

        # Re-insert the singleton axis that was flattened away above.
        data = torch.from_numpy(features).unsqueeze(-2)
        target = torch.tensor(labels)

        data, target = data.to(device), target.to(device)
        output = model(data)
        # original output dimensions are batchSizex1x10; the loss function
        # expects a batchSizex10 input
        output = output[:, 0, :]
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            # Accuracy is only reported, never optimized, so the device-to-host
            # copy it needs is paid only on the batches that are logged.
            pred = output.detach().argmax(dim=1)  # index of the max log-probability
            accuracy = accuracy_score(labels, pred.cpu().numpy())
            print(
                "Train Epoch: {}, Loss: {:.4f}, Accuracy: {:.4f}".format(
                    epoch, loss.item(), accuracy
                )
            )