# egg/zoo/language_bottleneck/mnist_overfit/train.py
# Imports needed by main(). The egg.core names are EGG's public API; the
# repo-local paths (archs, data, relaxed_channel) are assumptions based on
# the zoo layout.
import torch
import torch.utils.data
from torchvision import datasets, transforms

import egg.core as core
from egg.core import EarlyStopperAccuracy
from egg.zoo.language_bottleneck.mnist_classification.data import DoubleMnist
from egg.zoo.language_bottleneck.mnist_overfit.archs import Receiver, Sender
from egg.zoo.language_bottleneck.mnist_overfit.data import corrupt_labels_
from egg.zoo.language_bottleneck.relaxed_channel import AlwaysRelaxedWrapper

# get_params (CLI parsing) and diff_loss_symbol (the game loss) are assumed to
# be defined elsewhere in this module.


def main(params):
    opts = get_params(params)
    print(opts)

    kwargs = {"num_workers": 1, "pin_memory": True} if opts.cuda else {}
    transform = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        "./data", train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(
        "./data", train=False, download=False, transform=transform
    )
    n_classes = 10
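
    # Corrupt a fraction p_corrupt of the training labels in place (note the
    # trailing underscore), so that fitting them requires memorization.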
    corrupt_labels_(
        dataset=train_dataset, p_corrupt=opts.p_corrupt, seed=opts.random_seed + 1
    )
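
    # DoubleMnist yields composite two-digit examples with labels in [0, 100);
    # label_mapping collapses them back onto the 10 digit classes via x % 10.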
    label_mapping = torch.LongTensor([x % n_classes for x in range(100)])

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=True, **kwargs
    )
    train_loader = DoubleMnist(train_loader, label_mapping)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16 * 1024, shuffle=False, **kwargs
    )
    test_loader = DoubleMnist(test_loader, label_mapping)
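
    # --deeper adds depth to exactly one agent: the Sender (Alice) when
    # --deeper_alice == 1, otherwise the Receiver (Bob).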
    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity == 1,
    )
    receiver = Receiver(
        vocab_size=opts.vocab_size, n_classes=n_classes, deeper=deeper_bob
    )
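
    # Channel selection: by default the symbol stays relaxed (continuous);
    # --force_discrete switches to a discrete Gumbel-Softmax channel; the
    # linear and softmax channels are built into the Sender and need no wrapper.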
    if (
        opts.softmax_non_linearity != 1
        and opts.linear_channel != 1
        and opts.force_discrete != 1
    ):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)
    elif opts.force_discrete == 1:
        sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())
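
    # EarlyStopperAccuracy halts training once validation accuracy reaches
    # early_stopping_thr.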
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=True),
            EarlyStopperAccuracy(opts.early_stopping_thr),
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
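

if __name__ == "__main__":
    # Standard EGG zoo entry point; an assumption here, since the excerpt
    # above shows only main() itself.
    import sys

    main(sys.argv[1:])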