def main()

in evals/elsuite/hr_ml_agent_bench/benchmarks/cifar10/env/train.py [0:0]


def main():
    """Train ``Net`` on CIFAR-10 and write test-set predictions to ``submission.csv``.

    Steps performed:

    1. Download/load the CIFAR-10 train and test splits into ``./data``,
       applying the module-level ``transform``.
    2. Train ``Net`` for 5 epochs with SGD (lr=0.1, momentum=0.9) and
       cross-entropy loss, printing the mean loss every 100 mini-batches.
    3. After each epoch, print the values returned by ``test_model`` for the
       train and test dataloaders (formatted as percentages).
    4. Run the trained model on every test image and store the 10 softmax
       class probabilities per image in ``submission.csv`` (one row per
       image, columns ``0``..``9``).

    ``Net``, ``transform`` and ``test_model`` are expected to be defined
    elsewhere in this module.
    """
    # Set device for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the CIFAR-10 dataset
    train_dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transform,
    )

    test_dataset = datasets.CIFAR10(
        root="./data",
        train=False,
        download=True,
        transform=transform,
    )

    # Define the dataloaders
    batch_size = 32

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4,
    )

    # Keep the final partial batch for evaluation: dropping it would silently
    # exclude test images from the reported test accuracy whenever the test
    # set size is not a multiple of batch_size.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=4,
    )

    # Define the model, optimizer, and loss function
    model = Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    # Train the model
    epochs = 5

    for epoch in range(epochs):
        # test_model() below may switch the model to eval mode, so re-enable
        # training mode at the start of every epoch.
        model.train()
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Report the mean loss over the last 100 mini-batches, then reset.
            if i % 100 == 99:
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}")
                running_loss = 0.0

        train_accuracy = test_model(model, device, train_dataloader)
        test_accuracy = test_model(model, device, test_dataloader)

        print(
            f"Epoch [{epoch+1}/{epochs}], Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%"
        )

    # Save the predictions to submission.csv
    submission = pd.DataFrame(columns=list(range(10)), index=range(len(test_dataset)))
    model.eval()

    # Inference only: disable autograd so no computation graph is built for
    # each of the per-image forward passes.
    with torch.no_grad():
        for idx, data in enumerate(test_dataset):
            # data is (image, label); add a batch dimension for the model.
            inputs = data[0].unsqueeze(0).to(device)
            pred = model(inputs)
            # Convert the single row of logits into class probabilities.
            pred = torch.softmax(pred[0], dim=0)
            submission.loc[idx] = pred.tolist()

    submission.to_csv("submission.csv")