in salina_examples/computer_vision/mnist/mnist_dataloader_torch_agent.py [0:0]
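# NOTE: assumed context for this excerpt (imports not shown here): argparse, torch,
# torch.nn.functional as F, torchvision.datasets / torchvision.transforms,
# torch.utils.data.DataLoader, plus the salina Workspace, Agents, TFLogger,
# ShuffledDatasetAgent and DataLoaderAgent classes, and the ConvNetAgent / test
# helpers presumably defined elsewhere in this file.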
def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=10000,
        metavar="N",
        help="number of epochs to train (default: 10000)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="learning rate (default: 0.001)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-verbose", action="store_true", default=False, help="disables console output"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="./tmp",
        metavar="DIR",
        help="Directory for logging",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./.data",
        metavar="DIR",
        help="Directory for the MNIST data",
    )
    args = parser.parse_args()
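    # Use the first CUDA device when available and not disabled via --no-cuda;
    # otherwise fall back to the CPU.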
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cpu")
    if use_cuda:
        device = torch.device("cuda:0")
    torch.manual_seed(args.seed)
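    # Standard MNIST preprocessing: convert images to tensors and normalize with
    # the dataset-wide mean (0.1307) and standard deviation (0.3081).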
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_dataset = datasets.MNIST(
        args.data_dir, train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(args.data_dir, train=False, transform=transform)
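    # ShuffledDatasetAgent samples random training batches and writes them to the
    # workspace under the variables "x" (images) and "y" (labels); the
    # DataLoaderAgent below wraps a regular PyTorch DataLoader for test iteration.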
    train_agent = ShuffledDatasetAgent(
        train_dataset, batch_size=args.batch_size, output_names=("x", "y")
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
    )
    test_agent = DataLoaderAgent(test_dataloader, output_names=("x", "y"))
    agent = ConvNetAgent()
    train_agent = Agents(train_agent, agent)
    train_agent.seed(0)
    test_agent.seed(1)
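    # Log scalars and the hyperparameters (hps) with salina's TFLogger so they can
    # be inspected in TensorBoard; every_n_seconds throttles how often output is
    # written.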
    logger = TFLogger(
        log_dir=args.log_dir,
        hps={k: v for k, v in args.__dict__.items()},
        every_n_seconds=10,
        verbose=not args.no_verbose,
    )
    optimizer = torch.optim.Adam(train_agent.parameters(), lr=args.lr)
    train_agent.to(device)
    test_agent.to(device)
    train_workspace = Workspace()
    iteration = 0
    avg_loss = None
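    # Each epoch: evaluate on the test set first, then run roughly one pass over
    # the training set, sampling one shuffled batch per inner iteration.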
    for epoch in range(args.max_epochs):
        print(f"-- Training, Epoch {epoch+1}")
        loss, accuracy = test(test_agent, agent)
        logger.add_scalar("test/loss", loss.item(), epoch)
        logger.add_scalar("test/accuracy", accuracy, epoch)
        agent.train()
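        # Calling train_agent(train_workspace) samples a batch into "x"/"y" and
        # runs the ConvNet forward pass, which writes its predictions under "py".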
        for k in range(int(len(train_dataset) / args.batch_size)):
            train_agent(train_workspace)
            y = train_workspace.get("y", 0)
            pred = train_workspace.get("py", 0)
            loss = F.cross_entropy(pred, y, reduction="none")
            loss = loss.mean()
            logger.add_scalar("train/loss", loss.item(), iteration)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iteration += 1
    print("Done!")