in salina_examples/computer_vision/sequential_mnist/sequential_mnist_dataloader_torch_agent.py [0:0]
def main():
    """Train and evaluate an IRNN agent on (optionally permuted) Sequential MNIST.

    Parses command-line hyperparameters, builds a salina training agent
    (shuffled-dataset loader + IRNN model) and a test agent (plain DataLoader
    + the same model), then alternates: evaluate on the test set, run one
    epoch of SGD with gradient clipping on the training set. Metrics are
    written to TensorBoard via TFLogger.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description="PyTorch Sequential MNIST Classification Example"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        metavar="N",
        help="input batch size for training (default: 16)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=16,
        metavar="N",
        help="input batch size for testing (default: 16)",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=10000,
        metavar="N",
        # Help text previously claimed "(default: 14)" while the actual
        # default is 10000; keep the two in sync.
        help="number of epochs to train (default: 10000)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=5e-3,
        metavar="LR",
        help="learning rate (default: 5e-3)",
    )
    parser.add_argument(
        "--clip",
        type=float,
        default=1.0,
        metavar="CP",
        help="clipping threshold (default: 1.0)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-verbose", action="store_true", default=False, help="disables console output"
    )
    parser.add_argument(
        "--permute", action="store_true", default=False, help="Permute pixels"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="./tmp",
        metavar="N",
        help="Directory for logging",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./.data",
        metavar="N",
        # Was "Directory for logging" — copy-paste error from --log-dir.
        help="Directory for the MNIST data",
    )
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
    torch.manual_seed(args.seed)

    train_dataset = SequentialMNIST(
        data_dir=args.data_dir, train=True, permute=args.permute
    )
    test_dataset = SequentialMNIST(
        data_dir=args.data_dir, train=False, permute=args.permute
    )

    # Training samples are drawn shuffled by the salina agent itself;
    # the test set goes through a standard (unshuffled) DataLoader.
    train_agent = ShuffledDatasetAgent(
        train_dataset, batch_size=args.batch_size, output_names=("x", "y")
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
    )
    test_agent = DataLoaderAgent(test_dataloader, output_names=("x", "y"))

    # The same model agent is chained after the train loader; `test(...)`
    # receives it separately for evaluation.
    agent = IRNNAgent()
    train_agent = Agents(train_agent, agent)
    train_agent.seed(0)
    test_agent.seed(1)

    logger = TFLogger(
        log_dir=args.log_dir,
        hps=dict(vars(args)),  # snapshot of all CLI hyperparameters
        every_n_seconds=10,
        verbose=not args.no_verbose,
    )

    optimizer = torch.optim.SGD(train_agent.parameters(), lr=args.lr)
    train_agent.to(device)
    test_agent.to(device)

    train_workspace = Workspace()
    iteration = 0
    for epoch in range(args.max_epochs):
        print(f"-- Training, Epoch {epoch+1}")
        # NOTE(review): the model is never switched to eval mode here;
        # presumably `test(...)` handles it — confirm, otherwise dropout/BN
        # layers (if any) would run in training mode during evaluation.
        loss, accuracy = test(test_agent, agent)
        logger.add_scalar("test/loss", loss.item(), epoch)
        logger.add_scalar("test/accuracy", accuracy, epoch)

        agent.train()
        # One epoch = as many full batches as the dataset holds (remainder
        # is dropped); use floor division rather than a float round-trip.
        for _ in range(len(train_dataset) // args.batch_size):
            train_agent(train_workspace)
            y = train_workspace.get("y", 0)
            pred = train_workspace.get("py", 0)
            loss = F.cross_entropy(pred, y, reduction="none").mean()
            logger.add_scalar("train/loss", loss.item(), iteration)

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping keeps the IRNN recurrence from exploding.
            torch.nn.utils.clip_grad_norm_(train_agent.parameters(), args.clip)
            optimizer.step()
            iteration += 1
    print("Done!")