in egg/zoo/channel/train.py [0:0]
def _build_probs(probs_spec, n_features):
    """Return the normalized sampling distribution over the ``n_features`` inputs.

    Args:
        probs_spec: ``"uniform"``, ``"powerlaw"`` (weight of feature i is 1/i,
            1-based), or a comma-separated list of exactly ``n_features``
            non-negative weights.
        n_features: number of input features.

    Returns:
        A 1-D numpy array of length ``n_features`` normalized to sum to 1
        (float64 for ``"uniform"``, float32 otherwise, as before).

    Raises:
        ValueError: if a custom weight cannot be parsed as a float, if the custom
            list has the wrong length, or if the weights are negative, non-finite,
            or sum to zero (normalizing would otherwise yield NaNs/invalid probs).
    """
    if probs_spec == "uniform":
        probs = np.ones(n_features)
    elif probs_spec == "powerlaw":
        probs = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in probs_spec.split(",")], dtype=np.float32)
        if len(probs) != n_features:
            raise ValueError(
                f"expected {n_features} comma-separated probabilities, "
                f"got {len(probs)}: {probs_spec!r}"
            )
        if not np.all(np.isfinite(probs)) or (probs < 0).any() or probs.sum() <= 0:
            raise ValueError(
                "probabilities must be finite, non-negative and have a positive "
                f"sum: {probs_spec!r}"
            )
    probs /= probs.sum()
    return probs


def _build_sender(opts):
    """Build the sender agent wrapped for REINFORCE training (Transformer or RNN)."""
    if opts.sender_cell == "transformer":
        # Here the agent's output size is the embedding size (the RNN branch uses
        # the hidden size instead) — presumably because the Transformer wrapper
        # consumes it as an embedding.
        agent = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        return core.TransformerSenderReinforce(
            agent=agent,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )
    agent = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
    return core.RnnSenderReinforce(
        agent,
        opts.vocab_size,
        opts.sender_embedding,
        opts.sender_hidden,
        cell=opts.sender_cell,
        max_len=opts.max_len,
        num_layers=opts.sender_num_layers,
    )


def _build_receiver(opts):
    """Build the receiver agent wrapped as a deterministic Transformer or RNN reader."""
    if opts.receiver_cell == "transformer":
        # As for the sender, the agent's size is the embedding size in this branch.
        agent = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        return core.TransformerReceiverDeterministic(
            agent,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )
    agent = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
    return core.RnnReceiverDeterministic(
        agent,
        opts.vocab_size,
        opts.receiver_embedding,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.receiver_num_layers,
    )


def main(params):
    """Train a sender/receiver game on one-hot inputs and dump the result.

    Parses ``params`` with ``get_params``, trains a ``SenderReceiverRnnReinforce``
    game on one-hot inputs sampled according to ``opts.probs``, validates on a
    single identity batch, optionally checkpoints, then calls ``dump`` on the
    trained game and closes ``core``.

    Args:
        params: arguments forwarded to ``get_params`` (presumably CLI-style
            strings — confirm against ``get_params``).

    Raises:
        ValueError: if ``opts.probs`` is a malformed custom distribution.
    """
    opts = get_params(params)
    print(opts, flush=True)

    # For compatibility, after https://github.com/facebookresearch/EGG/pull/130
    # the meaning of `length` changed a bit. Before it included the EOS symbol; now
    # it doesn't. To ensure that hyperparameters/CL arguments do not change,
    # we subtract it here.
    opts.max_len -= 1

    device = opts.device

    probs = _build_probs(opts.probs, opts.n_features)
    print("the probs are: ", probs, flush=True)

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )
    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    sender = _build_sender(opts)
    receiver = _build_receiver(opts)

    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=empty_logger,
        length_cost=opts.length_cost,
    )
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]
    if opts.checkpoint_dir:
        # Note: opts.max_len is the already-decremented value (see above).
        checkpoint_name = (
            f"{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}"
            f"_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}"
            f"_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}"
            f"_max_len{opts.max_len}"
        )
        callbacks.append(
            core.CheckpointSaver(
                checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name
            )
        )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # NOTE(review): the game was configured via `train_logging_strategy=`; confirm
    # that `logging_strategy` is an attribute the game (or `dump`) actually reads,
    # otherwise this assignment has no effect.
    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()