in egg/zoo/language_bottleneck/guess_number/train.py [0:0]
def _build_fixed_length_game(opts):
    """Build ``(sender, receiver, game)`` for single-symbol communication.

    Supported ``opts.mode`` values:
      * ``"gs"``       - Gumbel-Softmax sender, ``SymbolGameGS`` with ``diff_loss``.
      * ``"rf"``       - REINFORCE sender, deterministic-REINFORCE receiver,
                         ``SymbolGameReinforce`` with ``diff_loss``.
      * ``"non_diff"`` - REINFORCE sender, ``ReinforcedReceiver`` receiver,
                         ``SymbolGameReinforce`` with ``non_diff_loss`` and a
                         receiver entropy bonus.

    Raises:
        ValueError: if ``opts.mode`` is not one of the values above. Without
            this check the caller would later fail with a NameError on the
            unbound ``receiver`` / ``game``.
    """
    # The bare sender agent is built before any mode-specific wrapper, in the
    # same order as before, so module initialisation order is unchanged.
    sender = Sender(
        n_bits=opts.n_bits, n_hidden=opts.sender_hidden, vocab_size=opts.vocab_size
    )
    if opts.mode == "gs":
        sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
        receiver = core.SymbolReceiverWrapper(
            receiver,
            vocab_size=opts.vocab_size,
            agent_input_size=opts.receiver_hidden,
        )
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif opts.mode == "rf":
        sender = core.ReinforceWrapper(agent=sender)
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
        receiver = core.SymbolReceiverWrapper(
            receiver,
            vocab_size=opts.vocab_size,
            agent_input_size=opts.receiver_hidden,
        )
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
        )
    elif opts.mode == "non_diff":
        sender = core.ReinforceWrapper(agent=sender)
        receiver = ReinforcedReceiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
        receiver = core.SymbolReceiverWrapper(
            receiver,
            vocab_size=opts.vocab_size,
            agent_input_size=opts.receiver_hidden,
        )
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            non_diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
        )
    else:
        raise ValueError(
            f"Unknown mode {opts.mode!r}; expected one of 'gs', 'rf', 'non_diff'"
        )
    return sender, receiver, game


def _build_variable_length_game(opts):
    """Build ``(sender, receiver, game)`` for variable-length messages.

    Only REINFORCE training is supported here. Any other ``opts.mode`` is
    reported and then overwritten with ``"rf"`` in place. ``main`` later reads
    ``opts.mode`` to decide ``is_gs``, so the in-place change matters.

    The sender is wrapped by a transformer or RNN REINFORCE wrapper depending
    on ``opts.sender_cell``. The receiver is wrapped by a transformer or RNN
    deterministic wrapper depending on ``opts.receiver_cell``.
    """
    if opts.mode != "rf":
        print("Only mode=rf is supported atm")
        opts.mode = "rf"

    # The agent's output feeds the recurrent/transformer wrapper, whose hidden
    # size is opts.sender_hidden. Presumably that is why vocab_size is
    # sender_hidden here rather than the message vocabulary.
    # TODO: not really vocab
    sender = Sender(
        n_bits=opts.n_bits,
        n_hidden=opts.sender_hidden,
        vocab_size=opts.sender_hidden,
    )
    if opts.sender_cell == "transformer":
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            max_len=opts.max_len,
            num_layers=1,
            num_heads=1,
            hidden_size=opts.sender_hidden,
        )
    else:
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )

    if opts.receiver_cell == "transformer":
        # The transformer receiver's agent is sized by the embedding dim
        # (receiver_emb) rather than receiver_hidden.
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_emb,
            num_heads=1,
            hidden_size=opts.receiver_hidden,
            num_layers=1,
        )
    else:
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        diff_loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
    )
    return sender, receiver, game


def main(params):
    """Train a sender/receiver pair on the guess-number game and evaluate it.

    Args:
        params: forwarded to ``get_params`` (presumably the command-line
            argument list - confirm against the caller).

    Raises:
        ValueError: in fixed-length mode, if ``opts.mode`` is not one of
            ``"gs"``, ``"rf"`` or ``"non_diff"``.
    """
    opts = get_params(params)
    print(opts)
    device = opts.device

    train_loader = OneHotLoader(
        n_bits=opts.n_bits,
        bits_s=opts.bits_s,
        bits_r=opts.bits_r,
        batch_size=opts.batch_size,
        # NOTE(review): true division yields a float. It is left unchanged
        # because the loader's handling of a fractional count lives elsewhere,
        # and switching to // could shorten the epoch.
        batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size,
    )
    test_loader = UniformLoader(
        n_bits=opts.n_bits, bits_s=opts.bits_s, bits_r=opts.bits_r
    )
    # The test set is a single fixed batch; move its tensors to the device once.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if opts.variable_length:
        sender, receiver, game = _build_variable_length_game(opts)
    else:
        sender, receiver, game = _build_fixed_length_game(opts)

    # Sender and receiver get separate learning rates via Adam parameter groups.
    optimizer = torch.optim.Adam(
        [
            dict(params=sender.parameters(), lr=opts.sender_lr),
            dict(params=receiver.parameters(), lr=opts.receiver_lr),
        ]
    )

    intervention = CallbackEvaluator(
        test_loader,
        device=device,
        is_gs=opts.mode == "gs",
        loss=game.loss,
        var_length=opts.variable_length,
        input_intervention=True,
    )
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True),
            EarlyStopperAccuracy(opts.early_stopping_thr),
            intervention,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)
    core.close()