def __eq__()

in neuron_explainer/activations/derived_scalars/activations_and_metadata.py [0:0]


    def __eq__(self, other: Any) -> bool:
        """
        Compare the per-layer activations of two ActivationsAndMetadata objects.

        Note that this uses torch.allclose, rather than checking for precise equality.

        This permits ActivationsAndMetadata to be "equal" while having different dst
        and pass type. This is useful for situations where we want to compare two derived scalars
        that should be the same but that are computed in different ways.

        Only `activations_by_layer_index` is compared. Two objects are equal when:
          - they have the same set of layer indices, and
          - for every layer index, the tensors have identical shapes and allclose values.

        Tensors whose dtype or device differ are compared after converting both to a common
        dtype (torch.promote_types) on the first tensor's device. torch.allclose would
        otherwise raise a RuntimeError, and __eq__ should always return a bool.

        Returns False when `other` is not an ActivationsAndMetadata.
        """
        if not isinstance(other, ActivationsAndMetadata):
            return False

        def tensors_are_close(self_tensor: torch.Tensor, other_tensor: torch.Tensor) -> bool:
            # torch.allclose broadcasts, so an explicit shape check is required; otherwise
            # e.g. a shape (1,) tensor could be considered equal to a shape (n,) tensor.
            if self_tensor.shape != other_tensor.shape:
                return False
            # torch.allclose requires matching dtype and device and raises otherwise.
            # Bring both tensors to a common dtype on self's device before comparing.
            if self_tensor.dtype != other_tensor.dtype or self_tensor.device != other_tensor.device:
                common_dtype = torch.promote_types(self_tensor.dtype, other_tensor.dtype)
                self_tensor = self_tensor.to(dtype=common_dtype)
                other_tensor = other_tensor.to(device=self_tensor.device, dtype=common_dtype)
            return torch.allclose(self_tensor, other_tensor)

        def check_activations_by_layer_index_equality(
            self_activations_by_layer_index: dict[LayerIndex, torch.Tensor],
            other_activations_by_layer_index: dict[LayerIndex, torch.Tensor],
        ) -> bool:
            # Both objects must cover exactly the same layer indices.
            if set(self_activations_by_layer_index) != set(other_activations_by_layer_index):
                return False
            # Then every layer's tensors must match in shape and (approximately) in value.
            return all(
                tensors_are_close(self_tensor, other_activations_by_layer_index[layer_index])
                for layer_index, self_tensor in self_activations_by_layer_index.items()
            )

        return check_activations_by_layer_index_equality(
            self.activations_by_layer_index, other.activations_by_layer_index
        )