neuron_explainer/activations/derived_scalars/edge_attribution.py [203:216]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                    )

                    def attribution_fn(
                        resid: torch.Tensor,
                        grad: torch.Tensor,
                        layer_index: LayerIndex,
                        pass_type: PassType,
                    ) -> torch.Tensor:
                        """Multiply the recomputed activation by its gradient.

                        The activation is obtained by calling the enclosing-scope
                        ``activation_fn`` on ``(resid, layer_index, pass_type)``.
                        Its shape must equal ``grad.shape``; a mismatch triggers an
                        AssertionError carrying both shapes (when asserts are on).
                        """
                        act = activation_fn(resid, layer_index, pass_type)
                        # Same shapes are required so the product below is a plain
                        # elementwise multiply rather than a silent broadcast.
                        assert act.shape == grad.shape, (act.shape, grad.shape)
                        product = act * grad
                        return product
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



neuron_explainer/activations/derived_scalars/edge_attribution.py [250:263]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            )

            def attribution_fn(
                resid: torch.Tensor,
                grad: torch.Tensor,
                layer_index: LayerIndex,
                pass_type: PassType,
            ) -> torch.Tensor:
                """Multiply the recomputed activation by its gradient.

                The activation is obtained by calling the enclosing-scope
                ``activation_fn`` on ``(resid, layer_index, pass_type)``.
                Its shape must equal ``grad.shape``; a mismatch triggers an
                AssertionError carrying both shapes (when asserts are on).
                """
                act = activation_fn(resid, layer_index, pass_type)
                # Same shapes are required so the product below is a plain
                # elementwise multiply rather than a silent broadcast.
                assert act.shape == grad.shape, (act.shape, grad.shape)
                product = act * grad
                return product
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



