def train()

in pytorch_alternatives/custom_pytorch_nlp/src/main.py


# Imports used by train(); helpers such as Net, load_training_data, load_testing_data,
# load_embeddings, test, and save_model are defined elsewhere in main.py.
import torch
import torch.nn.functional as F
import torch.optim as optim


def train(args):
    ###### Load data from input channels ############
    train_loader = load_training_data(args.train)
    test_loader = load_testing_data(args.test)
    embedding_matrix = load_embeddings(args.embeddings)

    ###### Setup model architecture ############
    model = Net(
        vocab_size=embedding_matrix.shape[0],
        emb_dim=embedding_matrix.shape[1],
        num_classes=args.num_classes,
    )
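    # Initialize the embedding layer with the pre-trained matrix and freeze it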
    model.embedding.weight = torch.nn.Parameter(torch.FloatTensor(embedding_matrix), requires_grad=False)
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    model.to(device)
    optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        n_batches = 0
        for batch_idx, (X_train, y_train) in enumerate(train_loader, 1):
            data, target = X_train.to(device), y_train.to(device)
            optimizer.zero_grad()
            output = model(data)
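            # F.binary_cross_entropy expects probabilities, so Net's forward is assumed to end in a sigmoid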
            loss = F.binary_cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        print(f"epoch: {epoch}, train_loss: {running_loss / n_batches:.6f}")  # (Avg over batches)
        print("Evaluating model")
        test(model, test_loader, device)
    save_model(model, args.model_dir, args.max_seq_len)
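
For reference, train() is normally called from the script's entry point with hyperparameters and input-channel paths parsed from the command line. The sketch below shows one way to wire this up, assuming SageMaker-style environment variables for the channel and model locations; the flag names and defaults are illustrative assumptions, not the repository's actual parser.

if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser()
    # Hyperparameters (names and defaults are assumptions for illustration)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--num-classes", type=int, default=2)
    parser.add_argument("--max-seq-len", type=int, default=40)
    # Input channels and model output dir, defaulting to SageMaker-style environment variables
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test", type=str, default=os.environ.get("SM_CHANNEL_TEST"))
    parser.add_argument("--embeddings", type=str, default=os.environ.get("SM_CHANNEL_EMBEDDINGS"))
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR"))

    train(parser.parse_args())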