# main()
#
# Excerpt from salina_examples/computer_vision/mnist/mnist_spatial_transformer_network.py


def main():
    """Train a Spatial Transformer Network (STN) classifier on MNIST.

    Parses command-line hyper-parameters, builds the train/test agent
    pipelines, then for each epoch runs one evaluation pass (so epoch 0
    logs an untrained baseline) followed by one pass over the training
    set, logging losses and test accuracy through ``TFLogger``.
    """
    # ---- Command-line hyper-parameters ----
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--num-transforms",
        type=int,
        default=3,
        metavar="N",
        help="number of STN transformations (default: 3)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    # NOTE: underscore (not dash) kept for backward-compatible CLI usage.
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=10000,
        metavar="N",
        help="number of epochs to train (default: 10000)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="learning rate (default: 0.001)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-verbose",
        action="store_true",
        default=False,
        help="disable logger output on the console",
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="./tmp",
        metavar="N",
        help="Directory for logging",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./.data",
        metavar="N",
        help="Directory where the MNIST dataset is stored/downloaded",
    )
    args = parser.parse_args()

    # ---- Device selection and seeding ----
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cpu")
    if use_cuda:
        device = torch.device("cuda:0")
    torch.manual_seed(args.seed)

    # ---- Data: standard MNIST normalization constants (mean, std) ----
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_dataset = datasets.MNIST(
        args.data_dir, train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(args.data_dir, train=False, transform=transform)

    # ---- Agents: shuffled sampler feeds the STN model during training ----
    train_agent = ShuffledDatasetAgent(
        train_dataset, batch_size=args.batch_size, output_names=("x", "y")
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
    )
    test_agent = DataLoaderAgent(test_dataloader, output_names=("x", "y"))
    agent = STNAgent(input="x", output="py", num_transforms=args.num_transforms)
    train_agent = Agents(train_agent, agent)
    train_agent.seed(0)
    test_agent.seed(1)

    logger = TFLogger(
        log_dir=args.log_dir,
        hps=dict(vars(args)),  # copy so later mutation can't affect logged hps
        every_n_seconds=10,
        verbose=not args.no_verbose,
    )

    optimizer = torch.optim.Adam(train_agent.parameters(), lr=args.lr)

    train_agent.to(device)
    test_agent.to(device)

    train_workspace = Workspace()

    iteration = 0
    for epoch in range(args.max_epochs):
        print(f"-- Training, Epoch {epoch+1}")

        # Evaluate before training this epoch, so epoch 0 is the
        # untrained baseline.  `test` presumably switches the agent to
        # eval mode internally — confirm in its definition.
        loss, accuracy = test(test_agent, agent)
        logger.add_scalar("test/loss", loss.item(), epoch)
        logger.add_scalar("test/accuracy", accuracy, epoch)

        agent.train()
        # Floor division: the trailing partial batch is dropped each epoch.
        for k in range(len(train_dataset) // args.batch_size):
            train_agent(train_workspace)
            y = train_workspace.get("y", 0)
            pred = train_workspace.get("py", 0)
            loss = F.cross_entropy(pred, y, reduction="none")
            loss = loss.mean()

            logger.add_scalar("train/loss", loss.item(), iteration)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iteration += 1

    print("Done!")