in egg/zoo/language_bottleneck/intervention.py [0:0]
def intervention_message(self, game):
    """Measure receiver accuracy when messages are shuffled within each batch.

    For every batch in ``self.dataset`` the sender's messages are randomly
    permuted across the batch before being handed to the receiver, so each
    receiver gets a message that was produced for a different sample. The
    receiver's accuracy on these shuffled messages, scored against the
    original (unshuffled) labels, is averaged into ``mean_acc``.

    Information-theoretic statistics are computed on the *unshuffled* data:
    every original message (truncated by ``core.find_lengths`` for
    variable-length games) is paired with its own label.

    Args:
        game: object exposing ``sender`` and ``receiver`` callables. When
            ``self.is_gs`` is False (Reinforce), both return tuples and only
            their first element is used.

    Returns:
        dict with keys:
            ``mean_acc``: receiver accuracy on shuffled messages, averaged
                over batches (or over individual samples for the
                ``is_gs`` variable-length case); NaN if nothing was scored.
            ``label_entropy``: ``entropy`` of all collected labels.
            ``message_info``: ``mutual_info`` between original messages and
                labels.
            ``bob_label_mi``: ``mutual_info`` between receiver inputs and
                labels, or 0.0 when no batch had a receiver input.
            ``alice_label_mi``: ``mutual_info`` between sender inputs and
                labels.
    """
    mean_acc = 0.0
    n_acc_terms = 0
    bob_label_mi = 0.0
    corresponding_labels = []
    original_messages = []
    bob_inputs = []
    alice_inputs = []

    for batch in self.dataset:
        if not isinstance(batch, Batch):
            batch = Batch(*batch)
        batch = batch.to(self.device)
        sender_input, labels, receiver_input, aux_input = batch

        original_message = game.sender(sender_input, aux_input)
        # if Reinforce, agents return tuples
        if not self.is_gs:
            original_message = original_message[0]

        if receiver_input is not None:
            bob_inputs.extend(receiver_input)
        alice_inputs.extend(sender_input)

        # The intervention: each receiver gets a message that was produced
        # for a randomly chosen (generally different) sample of the batch.
        permutation = torch.randperm(original_message.size(0)).to(
            original_message.device
        )
        message = torch.index_select(original_message, 0, permutation)
        output = game.receiver(message, receiver_input, aux_input)
        if not self.is_gs:
            output = output[0]

        if not self.var_length:
            _, rest = self.loss(None, None, None, output, labels, aux_input)
            mean_acc += rest["acc"].mean().item()
            n_acc_terms += 1
            original_messages.extend(original_message)
        elif not self.is_gs:
            # Store the *unshuffled* messages: labels are collected in their
            # original order, so message_info must pair each message with
            # the label of the sample that produced it.
            original_lengths = core.find_lengths(original_message)
            original_messages.extend(
                msg[:length]
                for msg, length in zip(original_message, original_lengths)
            )
            # Same batch-level call as the fixed-length branch, aux_input
            # included, since both operate on the whole batch.
            _, rest = self.loss(None, None, None, output, labels, aux_input)
            mean_acc += rest["acc"].mean().item()
            n_acc_terms += 1
        else:
            # Reduce the last dimension with argmax to obtain symbol ids.
            original_symbols = original_message.argmax(dim=-1)
            original_lengths = core.find_lengths(original_symbols)
            original_messages.extend(
                msg[:length]
                for msg, length in zip(original_symbols, original_lengths)
            )
            # Score each receiver at the last position of the *shuffled*
            # message it actually received. Assumes receiver output is
            # indexed as (batch, step, ...) — TODO confirm.
            shuffled_lengths = core.find_lengths(message.argmax(dim=-1))
            for i, length in enumerate(shuffled_lengths):
                # NOTE(review): aux_input is not passed here because it is
                # batch-level and is not sliced per sample.
                _, rest = self.loss(
                    None,
                    None,
                    None,
                    output[i : i + 1, length - 1],
                    labels[i : i + 1],
                )
                mean_acc += rest["acc"].item()
                n_acc_terms += 1
        corresponding_labels.extend(labels)

    label_entropy = entropy(corresponding_labels)
    message_info = mutual_info(original_messages, corresponding_labels)
    if bob_inputs:
        bob_label_mi = mutual_info(bob_inputs, corresponding_labels)
    alice_label_mi = mutual_info(alice_inputs, corresponding_labels)
    # Avoid ZeroDivisionError when nothing was scored (e.g. empty dataset).
    mean_acc = mean_acc / n_acc_terms if n_acc_terms else float("nan")

    return dict(
        mean_acc=mean_acc,
        label_entropy=label_entropy,
        message_info=message_info,
        bob_label_mi=bob_label_mi,
        alice_label_mi=alice_label_mi,
    )