in egg/zoo/language_bottleneck/intervention.py [0:0]
def intervention_input(self, game):
    """Evaluate receiver accuracy when receiver inputs are shuffled within each batch.

    For every batch of ``self.dataset`` the sender produces a message from the
    batch's own ``sender_input``. ``receiver_input`` is then randomly permuted
    along dimension 0 before it is passed to the receiver together with that
    message. The resulting accuracy therefore measures performance with
    mismatched receiver inputs.

    Args:
        game: object exposing ``sender(sender_input, aux_input)`` and
            ``receiver(message, receiver_input, aux_input)``. When
            ``self.is_gs`` is false, both are expected to return tuples whose
            first element is used.

    Returns:
        ``dict(mean_acc=...)``. ``mean_acc`` is the accuracy averaged over all
        samples of the dataset, with every batch weighted by its size. It is
        ``nan`` if the dataset yields no samples.
    """
    acc_sum = 0.0
    n_samples = 0
    # Evaluation only: every result below is reduced to a Python float, so
    # there is no reason to build autograd graphs.
    with torch.no_grad():
        for batch in self.dataset:
            if not isinstance(batch, Batch):
                batch = Batch(*batch)
            batch = batch.to(self.device)
            sender_input, labels, receiver_input, aux_input = batch

            message = game.sender(sender_input, aux_input)
            # if Reinforce, agents return tuples
            if not self.is_gs:
                message = message[0]

            batch_size = receiver_input.size(0)
            # The permutation indexes receiver_input, so it must live on that
            # tensor's device. It is still drawn on CPU, keeping the RNG stream.
            permutation = torch.randperm(batch_size).to(receiver_input.device)
            receiver_input = torch.index_select(receiver_input, 0, permutation)

            output = game.receiver(message, receiver_input, aux_input)
            if not self.is_gs:
                output = output[0]

            if self.is_gs and self.var_length:
                # The GS variable-length receiver output is indexed per message
                # step (dim 1). Score each sample at the step where its message
                # ends, gathering all samples at once instead of looping.
                lengths = core.find_lengths(message.argmax(dim=-1))
                rows = torch.arange(batch_size, device=output.device)
                output = output[rows, lengths.to(output.device) - 1]

            _, rest = self.loss(None, None, None, output, labels, aux_input)
            # Weight by batch size so the result is a per-sample mean over the
            # dataset rather than a mean of batch means (which over-weights a
            # short final batch).
            acc_sum += rest["acc"].float().mean().item() * batch_size
            n_samples += batch_size

    # Guard against an empty dataset instead of dividing by zero.
    mean_acc = acc_sum / n_samples if n_samples else float("nan")
    return dict(mean_acc=mean_acc)