in egg/zoo/language_bottleneck/intervention.py [0:0]
def validation(self, game):
    """Dump the game's interactions on ``self.dataset`` and score its messages.

    Calls ``core.dump_interactions`` with this object's settings
    (``self.is_gs``, ``self.device``, ``self.var_length``) and computes two
    statistics over the dumped interactions:

    * ``codewords_entropy`` -- the result of ``entropy`` on the list of
      dumped messages;
    * ``majority_acc`` -- the accuracy of a predictor that, for every
      distinct message, outputs the label that co-occurs with it most
      often. This is the sum over messages of the most frequent label's
      count, divided by the number of (message, label) pairs.

    Args:
        game: the game to evaluate; forwarded to ``core.dump_interactions``.

    Returns:
        dict with keys ``codewords_entropy`` and ``majority_acc``.
        ``majority_acc`` is ``float("nan")`` when the dump contains no
        interactions, because accuracy is undefined on an empty set.
    """
    from collections import Counter, defaultdict

    interactions = core.dump_interactions(
        game,
        self.dataset,
        gs=self.is_gs,
        device=self.device,
        variable_length=self.var_length,
    )

    messages = [interactions.message[i] for i in range(interactions.size)]
    entropy_messages = entropy(messages)
    labels = [interactions.labels[i] for i in range(interactions.size)]

    # Maps message -> Counter(label -> number of co-occurrences).
    # Messages and labels are converted with _hashable_tensor so they can be
    # used as dict keys (presumably they are tensors; confirm against caller).
    label_counts = defaultdict(Counter)
    for message, label in zip(messages, labels):
        label_counts[_hashable_tensor(message)][_hashable_tensor(label)] += 1

    # Majority vote per message: predicting the most frequent label for a
    # message gets exactly that label's count right. Every Counter here has
    # at least one entry, so max() is never called on an empty sequence.
    correct = sum(max(counts.values()) for counts in label_counts.values())
    total = sum(sum(counts.values()) for counts in label_counts.values())

    # Guard an empty dump, which would otherwise raise ZeroDivisionError.
    majority_accuracy = correct / total if total else float("nan")

    return dict(codewords_entropy=entropy_messages, majority_acc=majority_accuracy)