def intervention_message()

in egg/zoo/language_bottleneck/intervention.py [0:0]


    def intervention_message(self, game):
        """Evaluate the game with the sender's messages shuffled across each batch.

        For every batch in ``self.dataset`` the sender's messages are permuted
        along the batch dimension before being given to the receiver, so the
        receiver typically reads a message produced for a different input. The
        reported accuracy is the receiver's accuracy under this intervention.

        Mutual-information statistics are computed on the *unpermuted* sender
        messages, so every message stays paired with the label of the input
        that produced it (in all three message regimes).

        Args:
            game: object exposing ``sender(sender_input, aux_input)`` and
                ``receiver(message, receiver_input, aux_input)``; when
                ``self.is_gs`` is false both return tuples whose first element
                is used.

        Returns:
            dict with keys:
              * ``mean_acc`` - receiver accuracy on permuted messages, averaged
                over batches (fixed-length and variable-length Reinforce) or
                over samples (variable-length Gumbel-Softmax);
              * ``label_entropy`` - ``entropy`` of all labels;
              * ``message_info`` - ``mutual_info`` between original messages
                and labels;
              * ``bob_label_mi`` - ``mutual_info`` between receiver inputs and
                labels, or 0.0 when no receiver inputs were seen;
              * ``alice_label_mi`` - ``mutual_info`` between sender inputs and
                labels.
        """

        def _aux_input_row(aux, index, batch_size):
            """Return the slice of ``aux`` belonging to sample ``index`` only."""
            # NOTE(review): assumes aux_input is None, a tensor, or a dict whose
            # per-sample entries are tensors with the batch on dim 0 - confirm.
            def take(value):
                if (
                    torch.is_tensor(value)
                    and value.dim() > 0
                    and value.size(0) == batch_size
                ):
                    return value[index : index + 1]
                return value

            if isinstance(aux, dict):
                return {key: take(value) for key, value in aux.items()}
            return take(aux)

        mean_acc = 0.0
        scaler = 0.0

        bob_label_mi = 0.0

        corresponding_labels = []
        original_messages = []
        bob_inputs = []
        alice_inputs = []

        for batch in self.dataset:
            if not isinstance(batch, Batch):
                batch = Batch(*batch)
            batch = batch.to(self.device)
            sender_input, labels, receiver_input, aux_input = batch

            original_message = game.sender(sender_input, aux_input)
            # if Reinforce, agents return tuples
            if not self.is_gs:
                original_message = original_message[0]

            if receiver_input is not None:
                bob_inputs.extend(receiver_input)
            alice_inputs.extend(sender_input)

            # The intervention: shuffle messages across the batch before the
            # receiver reads them. `labels` / `aux_input` stay in original order.
            permutation = torch.randperm(original_message.size(0)).to(
                original_message.device
            )
            message = torch.index_select(original_message, 0, permutation)
            output = game.receiver(message, receiver_input, aux_input)

            if not self.is_gs:
                output = output[0]

            if not self.var_length:
                _, rest = self.loss(None, None, None, output, labels, aux_input)
                mean_acc += rest["acc"].mean().item()
                scaler += 1

                original_messages.extend(original_message)
            elif not self.is_gs:
                # Keep the sender's own (unpermuted) messages, cut to their
                # lengths, so they remain aligned with `labels` for the MI.
                original_lengths = core.find_lengths(original_message)
                original_messages.extend(
                    msg[:length]
                    for msg, length in zip(original_message, original_lengths)
                )

                _, rest = self.loss(None, None, None, output, labels, aux_input)
                mean_acc += rest["acc"].mean().item()
                scaler += 1
            else:
                # Gumbel-Softmax messages: reduce to symbol ids, then keep the
                # unpermuted ones (aligned with `labels`) for the MI.
                original_symbols = original_message.argmax(dim=-1)
                original_lengths = core.find_lengths(original_symbols)
                original_messages.extend(
                    msg[:length]
                    for msg, length in zip(original_symbols, original_lengths)
                )

                # The receiver read the permuted messages, so its output for
                # row i is taken at the last position of permuted message i.
                permuted_symbols = message.argmax(dim=-1)
                lengths = core.find_lengths(permuted_symbols)
                batch_size = lengths.size(0)

                for i in range(batch_size):
                    curr_len = lengths[i]
                    _, rest = self.loss(
                        None,
                        None,
                        None,
                        output[i : i + 1, curr_len - 1],
                        labels[i : i + 1],
                        _aux_input_row(aux_input, i, batch_size),
                    )
                    mean_acc += rest["acc"].item()
                    scaler += 1

            corresponding_labels.extend(labels)

        label_entropy = entropy(corresponding_labels)

        message_info = mutual_info(original_messages, corresponding_labels)
        if bob_inputs:
            bob_label_mi = mutual_info(bob_inputs, corresponding_labels)
        alice_label_mi = mutual_info(alice_inputs, corresponding_labels)

        mean_acc /= scaler

        s = dict(
            mean_acc=mean_acc,
            label_entropy=label_entropy,
            message_info=message_info,
            bob_label_mi=bob_label_mi,
            alice_label_mi=alice_label_mi,
        )

        return s