def intervention_input()

in egg/zoo/language_bottleneck/intervention.py [0:0]


    def intervention_input(self, game):
        """Measure receiver accuracy after shuffling receiver inputs within each batch.

        For every batch of ``self.dataset``, the sender encodes ``sender_input``
        unchanged, but ``receiver_input`` is permuted along dim 0 with a fresh
        random permutation before being passed to the receiver. The accuracy
        reported by ``self.loss`` (``rest["acc"]``) is then averaged.

        Args:
            game: object exposing ``sender(sender_input, aux_input)`` and
                ``receiver(message, receiver_input, aux_input)``. Under
                Reinforce (``not self.is_gs``) both return tuples whose first
                element is used.

        Returns:
            dict with the single key ``"mean_acc"``. In the Gumbel-Softmax
            variable-length setting this is the mean over individual examples;
            otherwise it is the mean of per-batch mean accuracies. It is
            ``nan`` when ``self.dataset`` yields no batches.
        """
        mean_acc = 0.0
        scaler = 0.0

        # Pure evaluation: no autograd graph is needed for these forward passes.
        with torch.no_grad():
            for batch in self.dataset:
                if not isinstance(batch, Batch):
                    batch = Batch(*batch)
                batch = batch.to(self.device)
                sender_input, labels, receiver_input, aux_input = batch

                message = game.sender(sender_input, aux_input)
                # if Reinforce, agents return tuples
                if not self.is_gs:
                    message = message[0]

                # The intervention: shuffle receiver inputs across the batch.
                # The permutation is drawn on the default generator (as before) and
                # moved to receiver_input's device, which index_select requires.
                permutation = torch.randperm(receiver_input.size(0)).to(
                    receiver_input.device
                )
                receiver_input = torch.index_select(receiver_input, 0, permutation)
                output = game.receiver(message, receiver_input, aux_input)

                if not self.is_gs:
                    output = output[0]
                if self.is_gs and self.var_length:
                    # presumably message is a relaxed one-hot (batch, len, vocab);
                    # argmax recovers symbol ids for find_lengths - TODO confirm
                    message = message.argmax(dim=-1)
                    lengths = core.find_lengths(message)

                    # Score each example separately at index curr_len - 1 along dim 1.
                    # NOTE(review): aux_input is passed for the whole batch while
                    # output/labels are sliced to one example - confirm the loss
                    # does not index aux_input per example.
                    for i, curr_len in enumerate(lengths):
                        _, rest = self.loss(
                            None,
                            None,
                            None,
                            output[i : i + 1, curr_len - 1],
                            labels[i : i + 1],
                            aux_input,
                        )
                        mean_acc += rest["acc"].item()
                        scaler += 1
                else:
                    # NOTE(review): this averages per-batch means, so a smaller
                    # final batch weighs as much as a full one (unlike the
                    # per-example averaging in the branch above).
                    _, rest = self.loss(None, None, None, output, labels, aux_input)
                    mean_acc += rest["acc"].mean().item()
                    scaler += 1.0

        # Guard against an empty dataset: nan signals "nothing evaluated"
        # instead of raising ZeroDivisionError.
        mean_acc = mean_acc / scaler if scaler else float("nan")
        s = dict(
            mean_acc=mean_acc,
        )

        return s