def train()

in source/sagemaker/src/package/model/train.py [0:0]


def train(model, train_iterator, valid_iterator, n_epochs, model_dir):
    """Train ``model`` for ``n_epochs`` and checkpoint the best epoch.

    Each epoch runs ``train_one_epoch`` over ``train_iterator`` followed by
    ``evaluate`` over ``valid_iterator``. Whenever the validation loss is
    strictly lower than the best seen so far, the model's ``state_dict`` is
    saved to ``<model_dir>/model.pt``, overwriting the previous checkpoint.
    Per-epoch timing, loss and accuracy are printed to stdout.

    Args:
        model: The ``torch.nn.Module`` to train. It is moved to CUDA when
            available, otherwise to CPU. The loss is ``BCEWithLogitsLoss``,
            so the model presumably outputs raw (un-sigmoided) logits --
            confirm against the model definition.
        train_iterator: Batches consumed by ``train_one_epoch``.
        valid_iterator: Batches consumed by ``evaluate``.
        n_epochs: Number of epochs to run.
        model_dir: Directory in which ``model.pt`` is written.

    Returns:
        The best validation loss observed. This is ``float('inf')`` when
        ``n_epochs`` is 0 or when no validation loss compared below infinity
        (e.g. every loss was NaN); in that case no checkpoint is written.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move the model BEFORE building the optimizer: torch.optim expects the
    # parameters it is given to already live on their final device.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())

    criterion = nn.BCEWithLogitsLoss().to(device)

    # NOTE(review): batches are not moved to `device` here; presumably the
    # iterators (or train_one_epoch/evaluate) handle placement -- confirm.

    best_valid_loss = float('inf')
    model_path = join(model_dir, "model.pt")

    for epoch in range(n_epochs):
        print(f'Epoch: {epoch + 1:02} started...')

        start_time = time.time()
        train_loss, train_acc = train_one_epoch(model, train_iterator, optimizer, criterion)
        valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Keep only the checkpoint with the lowest validation loss so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), model_path)

        print(f'Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc * 100:.2f}%')

    return best_valid_loss