in neuron_explainer/activations/derived_scalars/activations_and_metadata.py [0:0]
def __eq__(self, other: Any) -> bool:
    """
    Compare two ActivationsAndMetadata by their activations, approximately.

    Note that this uses torch.allclose, rather than checking for precise equality.
    Only ``activations_by_layer_index`` is compared, which permits two
    ActivationsAndMetadata to be "equal" while having different dst and pass type.
    This is useful for situations where we want to compare two derived scalars that
    should be the same but that are computed in different ways.

    Two instances are considered equal when:
    - their ``activations_by_layer_index`` mappings have the same set of layer
      indices, and
    - for every layer index, the two tensors have the same shape and are
      ``torch.allclose`` (default tolerances) after being cast to a common dtype
      and moved to ``self``'s device.

    Args:
        other: Object to compare against.

    Returns:
        True if the activations match as described above. False otherwise,
        including when ``other`` is not an ActivationsAndMetadata.
    """
    if not isinstance(other, ActivationsAndMetadata):
        return False
    self_by_layer = self.activations_by_layer_index
    other_by_layer = other.activations_by_layer_index
    # Both sides must cover exactly the same layer indices.
    if set(self_by_layer.keys()) != set(other_by_layer.keys()):
        return False
    for layer_index, self_activations in self_by_layer.items():
        other_activations = other_by_layer[layer_index]
        # Explicit shape check: torch.allclose broadcasts its inputs, so without
        # this a tensor could compare "close" to a differently-shaped one.
        if self_activations.shape != other_activations.shape:
            return False
        # torch.allclose raises (instead of returning False) when dtypes or devices
        # differ. Scalars computed in different ways may legitimately differ in
        # either, so compare in a common dtype on self's device.
        common_dtype = torch.promote_types(
            self_activations.dtype, other_activations.dtype
        )
        if not torch.allclose(
            self_activations.to(dtype=common_dtype),
            other_activations.to(device=self_activations.device, dtype=common_dtype),
        ):
            return False
    return True