in egg/zoo/simple_autoenc/train.py [0:0]
def _build_game(opts, sender, receiver):
    """Wrap the base agents for the requested training mode and assemble the game.

    Args:
        opts: parsed options. Reads ``mode`` plus the vocabulary, embedding,
            hidden-size, cell, ``max_len``, ``temperature`` and entropy
            coefficient settings.
        sender: the base sender network to wrap in an RNN sender.
        receiver: the base receiver network to wrap in an RNN receiver.

    Returns:
        A ``(game, callbacks)`` pair. ``callbacks`` holds the mode-specific
        trainer callbacks; it is empty for REINFORCE ("rf").

    Raises:
        NotImplementedError: if ``opts.mode`` is neither "rf" nor "gs"
            (compared case-insensitively).
    """
    mode = opts.mode.lower()

    if mode == "rf":
        rnn_sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
        )
        rnn_receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnReinforce(
            rnn_sender,
            rnn_receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
        )
        return game, []

    if mode == "gs":
        rnn_sender = core.RnnSenderGS(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )
        rnn_receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnGS(rnn_sender, rnn_receiver, loss)
        # The updater is attached to the GS-wrapped sender, which owns the
        # temperature. decay/minimum presumably give a multiplicative decay
        # with a floor; confirm in core.TemperatureUpdater.
        callbacks = [core.TemperatureUpdater(agent=rnn_sender, decay=0.9, minimum=0.1)]
        return game, callbacks

    raise NotImplementedError(f"Unknown training mode, {opts.mode}")


def main(params):
    """Train a sender/receiver game on one-hot inputs and run validation.

    Args:
        params: raw parameters forwarded to ``get_params``. Presumably a
            command-line style argument list; confirm against ``get_params``.

    Raises:
        NotImplementedError: if the configured training mode is not "rf"
            or "gs".
    """
    opts = get_params(params)

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
    )
    # The validation loader gets a fixed seed and the training loader does not.
    # Presumably this makes the validation batches reproducible across runs;
    # confirm in OneHotLoader.
    test_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        seed=7,
    )

    sender = Sender(n_hidden=opts.sender_hidden, n_features=opts.n_features)
    receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
    game, callbacks = _build_game(opts, sender, receiver)

    # Use separate parameter groups so that the sender and the receiver can
    # be trained with different learning rates.
    optimizer = torch.optim.Adam(
        [
            {"params": game.sender.parameters(), "lr": opts.sender_lr},
            {"params": game.receiver.parameters(), "lr": opts.receiver_lr},
        ]
    )
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks + [core.ConsoleLogger(as_json=True)],
    )
    try:
        trainer.train(n_epochs=opts.n_epochs)
    finally:
        # Call core.close() even if training raises, so that whatever EGG
        # set up is released. This is presumably global state created during
        # get_params/core.init; confirm.
        core.close()