in neuron_explainer/activations/derived_scalars/derived_scalar_store.py [0:0]
def __eq__(self, other: Any) -> bool:
    """
    Return True when `other` is a DerivedScalarStore holding the same entries.

    Two stores are equal when both conditions hold:
    - they have the same set of keys in `activations_and_metadata_by_dst_and_pass_type`;
    - every stored value compares equal to its counterpart under the value's own `==`.

    NOTE: this is expected to be approximate rather than exact equality. Per this
    method's original note, the per-entry comparison relies on torch.allclose. That
    comparison is implemented by the stored values' `__eq__`, which is not shown here;
    confirm against that type.

    Any `other` that is not a DerivedScalarStore compares unequal; this method returns
    False rather than NotImplemented.
    """
    if not isinstance(other, DerivedScalarStore):
        return False
    mine = self.activations_and_metadata_by_dst_and_pass_type
    theirs = other.activations_and_metadata_by_dst_and_pass_type
    # Key sets must match exactly before any values are compared.
    if set(mine.keys()) != set(theirs.keys()):
        return False
    # Compare entries in self's key order, stopping at the first mismatch.
    return all(mine[key] == theirs[key] for key in mine.keys())