def transform_attr()

in understanding_rl_vision/rl_clarity/interface.py [0:0]


            def transform_attr(attr):
                """Project an attribution array onto the ``nmf`` basis.

                When ``layer_name`` (from the enclosing scope) is None, no
                projection happens and ``(attr, None)`` is returned as-is.

                Otherwise the positive and negative parts of ``attr`` are each
                passed through ``nmf.transform``. Their difference gives signed
                coefficients. The coefficients are mapped back with
                ``nmf.inverse_transform``, again one sign at a time, to build a
                reconstruction of ``attr``.

                Returns:
                    A pair ``(scaled_coeffs, residual)``.

                    ``scaled_coeffs`` is the signed coefficients multiplied by
                    ``nmf.channel_dirs.sum(-1)``.

                    ``residual`` is ``attr`` minus its reconstruction, summed
                    over the last axis with that axis kept at size 1.
                """
                # Guard clause: without a layer there is nothing to project.
                if layer_name is None:
                    return attr, None

                # Each sign is handled separately, presumably because the NMF
                # transform expects non-negative input -- TODO confirm.
                positive_part = np.maximum(attr, 0)
                negative_part = np.maximum(-attr, 0)
                coeffs = nmf.transform(positive_part) - nmf.transform(negative_part)

                # Rebuild attr from the signed coefficients, one sign at a time.
                reconstruction = nmf.inverse_transform(
                    np.maximum(coeffs, 0)
                ) - nmf.inverse_transform(np.maximum(-coeffs, 0))
                residual = (attr - reconstruction).sum(-1, keepdims=True)

                component_norms = nmf.channel_dirs.sum(-1)
                # Prepending three singleton axes lets the per-component scale
                # broadcast along the last axis of ``coeffs``. This assumes
                # coeffs is 4-D with components last -- TODO confirm.
                scaled_coeffs = coeffs * component_norms[None, None, None]
                return scaled_coeffs, residual