neuron_explainer/activations/derived_scalars/edge_attribution.py [337:378]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        def get_activation_from_layer_activations(layer_activations: torch.Tensor) -> torch.Tensor:
            return layer_activations[node_index.tensor_indices]

        return get_activation_from_layer_activations

    def make_gradient_scalar_deriver_for_node_index(
        self,
        node_index: NodeIndex,
        dst_config: DstConfig,
        output_dst: DerivedScalarType | None = None,
    ) -> ScalarDeriver:
        """Build a gradient scalar deriver whose scalar hook reads the given node.

        The node index is first passed to ``self._check_node_index``. A copy of
        ``dst_config`` is then made with ``layer_indices`` set to the single
        layer of ``node_index``; the caller's config object is not modified.

        Args:
            node_index: Node to read; its ``layer_index`` must not be None.
            dst_config: Base configuration; only ``layer_indices`` is overridden
                in the copy handed to ``make_gradient_scalar_deriver``.
            output_dst: Forwarded unchanged to ``make_gradient_scalar_deriver``.

        Returns:
            Whatever ``self.make_gradient_scalar_deriver`` returns for the scalar
            hook of ``node_index`` and the single-layer config.

        Raises:
            AssertionError: if ``node_index.layer_index`` is None.
        """
        self._check_node_index(node_index)
        layer_index = node_index.layer_index
        assert layer_index is not None
        # dataclasses.replace returns a new config, leaving the caller's untouched.
        single_layer_config = dataclasses.replace(dst_config, layer_indices=[layer_index])
        return self.make_gradient_scalar_deriver(
            scalar_hook=self.make_scalar_hook_for_node_index(node_index),
            dst_config=single_layer_config,
            output_dst=output_dst,
        )

    def make_gradient_scalar_source_for_node_index(
        self,
        node_index: NodeIndex,
        dst_config: DstConfig,
        output_dst: DerivedScalarType | None = None,
    ) -> DerivedScalarSource:
        """Wrap the per-node gradient scalar deriver in a ``DerivedScalarSource``.

        The deriver is built by ``make_gradient_scalar_deriver_for_node_index``,
        so the node index is checked and the deriver's config is restricted to
        the layer of ``node_index`` -- the same layer the returned source is
        pinned to through ``ConstantLayerIndexer``.

        Args:
            node_index: Node to read; its ``layer_index`` must not be None.
            dst_config: Base configuration for the deriver; its
                ``layer_indices`` are replaced by the node's layer.
            output_dst: Forwarded unchanged to the deriver factory.

        Returns:
            A ``DerivedScalarSource`` with ``pass_type=PassType.FORWARD`` and a
            constant layer indexer at ``node_index.layer_index``.

        Raises:
            AssertionError: if ``node_index.layer_index`` is None.
        """
        # Delegate to the per-node factory rather than calling
        # make_gradient_scalar_deriver directly: that keeps the node-index check
        # and the single-layer config in one place, and guarantees the deriver's
        # layer matches the layer indexer built below.
        gradient_scalar_deriver = self.make_gradient_scalar_deriver_for_node_index(
            node_index=node_index,
            dst_config=dst_config,
            output_dst=output_dst,
        )
        # Already asserted by the delegate; repeated so type checkers can narrow
        # layer_index to a non-None value here.
        assert node_index.layer_index is not None
        return DerivedScalarSource(
            scalar_deriver=gradient_scalar_deriver,
            pass_type=PassType.FORWARD,
            layer_indexer=ConstantLayerIndexer(node_index.layer_index),
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



