in egg/zoo/objects_game/train.py [0:0]
def main(params):
    """Train the sender/receiver game and optionally evaluate it on the test data.

    Steps, in order: parse the options, build the data iterators, print the
    baseline accuracies, wrap the agents for "gs" training, train, and --
    when ``opts.evaluate`` is set -- print test accuracy plus the entropy /
    mutual-information statistics. If ``opts.dump_msg_folder`` is also set,
    every test interaction is written to a ``.msg`` file in that folder.

    Args:
        params: Raw options forwarded unchanged to ``get_params``
            (presumably the command-line argument list -- confirm against
            the caller).

    Raises:
        NotImplementedError: If ``opts.mode`` is anything other than "gs"
            (compared case-insensitively).
    """
    opts = get_params(params)
    device = torch.device("cuda" if opts.cuda else "cpu")

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed,
    )
    train_data, validation_data, test_data = data_loader.get_iterators()
    # NOTE(review): presumably syncs `opts` with what the loader actually
    # produced (e.g. data read from a file) -- confirm in VectorsLoader.
    data_loader.upd_cl_options(opts)

    # Report baseline accuracies so the trained accuracy can be put in context.
    if opts.max_len > 1:
        baseline_msg = 'Cannot yet compute "smart" baseline value for messages of length greater than 1'
    else:
        baseline_msg = (
            f"\n| Baselines measures with {opts.n_distractors} distractors and messages of max_len = {opts.max_len}:\n"
            f"| Dummy random baseline: accuracy = {1 / (opts.n_distractors + 1)}\n"
        )
        if -1 not in opts.perceptual_dimensions:
            smart_baseline = compute_baseline_accuracy(
                opts.n_distractors, opts.max_len, *opts.perceptual_dimensions
            )
            baseline_msg += (
                f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions}'
                f" = {smart_baseline}\n"
            )
        else:
            # A -1 entry marks the absence of a usable perceptual_dimensions
            # vector, so the "smart" baseline cannot be derived.
            baseline_msg += (
                "| Data was loaded from an external file, thus no perceptual_dimension"
                ' vector was provided, "smart baseline" cannot be computed\n'
            )
    print(baseline_msg)

    sender = Sender(n_features=data_loader.n_features, n_hidden=opts.sender_hidden)
    receiver = Receiver(
        n_features=data_loader.n_features, linear_units=opts.receiver_hidden
    )

    # Resolve the mode once, case-insensitively, and reuse it everywhere.
    # The evaluation step used to compare case-sensitively, so "--mode GS"
    # trained in gs mode but was evaluated as if it were not.
    is_gs = opts.mode.lower() == "gs"
    if not is_gs:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    sender = core.RnnSenderGS(
        sender,
        opts.vocab_size,
        opts.sender_embedding,
        opts.sender_hidden,
        cell=opts.sender_cell,
        max_len=opts.max_len,
        temperature=opts.temperature,
    )
    receiver = core.RnnReceiverGS(
        receiver,
        opts.vocab_size,
        opts.receiver_embedding,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )
    game = core.SenderReceiverRnnGS(sender, receiver, loss)

    # Sender and receiver get separate parameter groups so that each can use
    # its own learning rate.
    optimizer = torch.optim.Adam(
        [
            {"params": game.sender.parameters(), "lr": opts.sender_lr},
            {"params": game.receiver.parameters(), "lr": opts.receiver_lr},
        ]
    )

    # Only "gs" mode reaches this point, so the temperature updater is always
    # installed alongside the console logger.
    callbacks = [
        core.ConsoleLogger(as_json=True),
        core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    if not opts.evaluate:
        return

    (
        sender_inputs,
        messages,
        receiver_inputs,
        receiver_outputs,
        labels,
    ) = dump_sender_receiver(
        game, test_data, is_gs, variable_length=True, device=device
    )

    receiver_outputs = move_to(receiver_outputs, device)
    labels = move_to(labels, device)
    receiver_outputs = torch.stack(receiver_outputs)
    labels = torch.stack(labels)

    # A prediction is correct when the arg-max over dim 1 equals the label.
    tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
    accuracy = torch.mean(tensor_accuracy.float()).item()

    # Distinct targets shown to the sender, keyed by a "v1-v2-..." string of
    # the integer-truncated feature values.
    unique_targets = {
        "-".join(str(int(value)) for value in elem.tolist())
        for elem in sender_inputs
    }

    print(f"| Accuracy on test set: {accuracy}")
    compute_mi_input_msgs(sender_inputs, messages)
    print(f"entropy sender inputs {entropy(sender_inputs)}")
    print(f"mi sender inputs msgs {mutual_info(sender_inputs, messages)}")

    if not opts.dump_msg_folder:
        return

    # parents=True so that a nested, not-yet-existing dump folder does not
    # abort the run after training has already finished.
    opts.dump_msg_folder.mkdir(parents=True, exist_ok=True)

    output_msg = (
        f"messages_{opts.perceptual_dimensions}_vocab_{opts.vocab_size}"
        f"_maxlen_{opts.max_len}_bsize_{opts.batch_size}"
        f"_n_distractors_{opts.n_distractors}_train_size_{opts.train_samples}"
        f"_valid_size_{opts.validation_samples}_test_size_{opts.test_samples}"
        f"_slr_{opts.sender_lr}_rlr_{opts.receiver_lr}_shidden_{opts.sender_hidden}"
        f"_rhidden_{opts.receiver_hidden}_semb_{opts.sender_embedding}"
        f"_remb_{opts.receiver_embedding}_mode_{opts.mode}"
        f"_scell_{opts.sender_cell}_rcell_{opts.receiver_cell}.msg"
    )
    output_file = opts.dump_msg_folder / output_msg

    msg_counts = {}
    # Explicit encoding keeps the dump independent of the platform locale.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(f"{opts}\n")
        for sender_input, message, receiver_input, receiver_output, label in zip(
            sender_inputs, messages, receiver_inputs, receiver_outputs, labels
        ):
            sender_str = ",".join(map(str, sender_input.tolist()))
            message_str = ",".join(map(str, message.tolist()))
            candidates_str = "; ".join(
                ",".join(map(str, candidate)) for candidate in receiver_input.tolist()
            )
            # Always true here, since only "gs" mode gets past the guard
            # above; kept explicit so the dump stays tied to the mode flag.
            if is_gs:
                receiver_output = receiver_output.argmax()
            f.write(
                f"{sender_str} -> {candidates_str} -> {message_str} -> {receiver_output} (label={label.item()})\n"
            )
            msg_counts[message_str] = msg_counts.get(message_str, 0) + 1

        # Most frequent messages first; ties keep first-seen order because
        # sorted() is stable.
        sorted_msgs = sorted(
            msg_counts.items(), key=operator.itemgetter(1), reverse=True
        )
        f.write(
            f"\nUnique target vectors seen by sender: {len(unique_targets)}\n"
        )
        f.write(f"Unique messages produced by sender: {len(msg_counts)}\n")
        # NOTE(review): the "Messagses" spelling is kept verbatim in case
        # existing tooling matches on this label in dump files.
        f.write(f"Messagses: 'msg' : msg_count: {str(sorted_msgs)}\n")
        f.write(f"\nAccuracy: {accuracy}")