in egg/zoo/compo_vs_generalization/train.py
def main(params):
    import copy

    opts = get_params(params)
    device = opts.device

    # Enumerate the full attribute-value input space; optionally keep only a subsample.
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    # Hold out inputs for the generalization test, then carve a further 10% uniform holdout.
    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # Validation reuses the training examples without scaling.
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)

    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]
    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))
    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ["lstm", "rnn", "gru"]:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size + 1,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
    else:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")

    if opts.sender_cell in ["lstm", "rnn", "gru"]:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )
    else:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

    # PlusOneWrapper shifts the Sender's symbols by one so that 0 stays reserved for EOS,
    # which is why the Receiver's vocabulary is vocab_size + 1.
    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )
    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()

    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]
    # If the game converged, train new agents against the frozen Sender.
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None
        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            trainer.train(n_epochs=opts.n_epochs // 2)

            # validation_stats holds (loss, Interaction) pairs, so read the per-epoch
            # accuracy from the Interaction's aux field, as done for the main game above.
            accs = [
                x[1].aux["acc"].mean().item()
                for x in early_stopper.validation_stats
            ]
            return accs
        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )
        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Pad with perfect accuracy for the epochs skipped by early stopping,
                # then report the area under the accuracy-vs-epoch curve.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")

    core.close()
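

# Not part of the excerpt above: the module is normally run as a script. A minimal
# entry point is sketched here, following the convention used elsewhere in the EGG
# zoo games (assumed, since the excerpt ends at main()).
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])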