def save_attr()

in understanding_rl_vision/rl_clarity/interface.py [0:0]


            def save_attr(attr, attr_res, *, type_, data):
                """Save attribution totals and rendered attribution images.

                When ``layer_name`` is not None, writes JSON totals for the
                channels and the residual and records them in ``totals``. Then,
                for every sign direction ("abs", "pos", "neg") and every entry
                in the channel list, it does three things:
                - renders a smoothed image array and saves it as a PNG;
                - saves a "scrub" slice of that array;
                - records the filename and metadata in ``attribution``.

                Args:
                    attr: attribution array. The last three axes are treated as
                        one item when computing the normalization scale
                        (presumably (height, width, channels) — TODO confirm
                        against the caller).
                    attr_res: residual attribution, or None to use zeros shaped
                        like ``attr`` with its last axis reduced to size 1.
                        Presumably always has a trailing axis of size 1 (it is
                        indexed with ``[..., 0]`` and repeated 3x on the last
                        axis).
                    type_: attribution type label, used in filenames and
                        metadata.
                    data: optional extra label. When not None it is appended
                        to filenames; it is always stored in metadata.
                """
                if attr_res is None:
                    attr_res = np.zeros_like(attr).sum(-1, keepdims=True)
                filename_root = f"{squash(layer_key)}_{index}_{type_}"
                if data is not None:
                    filename_root = f"{filename_root}_{data}"
                if layer_name is not None:
                    channels_filename = f"{filename_root}_channels.json"
                    residuals_filename = f"{filename_root}_residuals.json"
                    channels_path = os.path.join(
                        attribution_totals_subdir, channels_filename
                    )
                    residuals_path = os.path.join(
                        attribution_totals_subdir, residuals_filename
                    )
                    # Channel totals: attr summed over its original axes -3 and -2.
                    save_json(attr.sum(-2).sum(-2), channels_path)
                    # Residual totals: drop the trailing size-1 axis, then sum
                    # over the last two remaining axes.
                    save_json(attr_res[..., 0].sum(-1).sum(-1), residuals_path)
                    totals["channels"].append(channels_filename)
                    totals["residuals"].append(residuals_filename)
                    totals["metadata"]["type"].append(type_)
                    totals["metadata"]["data"].append(data)

                # Normalize by the median of the per-item maxima. The scale must
                # be strictly positive: dividing by a negative value would flip
                # every sign and swap the "pos" and "neg" renderings below.
                attr_scale = np.median(attr.max(axis=(-3, -2, -1)))
                if attr_scale <= 0:
                    attr_scale = attr.max()
                if attr_scale <= 0:
                    attr_scale = 1
                attr_scaled = attr / attr_scale
                attr_res_scaled = attr_res / attr_scale

                # "prin" renders the channels only; "all" also adds the residual.
                # Optionally add one entry per NMF feature plus the residual alone.
                channels = ["prin", "all"]
                if attr_single_channels and layer_name is not None:
                    channels += list(range(nmf.features)) + ["res"]

                for direction in ["abs", "pos", "neg"]:
                    if direction == "abs":
                        attr_dir = np.abs(attr_scaled)
                        attr_res_dir = np.abs(attr_res_scaled)
                    elif direction == "pos":
                        attr_dir = np.maximum(attr_scaled, 0)
                        attr_res_dir = np.maximum(attr_res_scaled, 0)
                    else:  # "neg": magnitude of the negative part
                        attr_dir = np.maximum(-attr_scaled, 0)
                        attr_res_dir = np.maximum(-attr_res_scaled, 0)
                    for channel in channels:
                        if isinstance(channel, int):
                            # Zero every channel except the selected one.
                            attr_single = attr_dir.copy()
                            attr_single[..., :channel] = 0
                            attr_single[..., (channel + 1) :] = 0
                            images = channels_to_rgb(attr_single)
                        elif channel == "res":
                            # Grey-scale: copy the single residual channel to RGB.
                            images = attr_res_dir.repeat(3, axis=-1)
                        else:
                            images = channels_to_rgb(attr_dir)
                            if channel == "all":
                                images += attr_res_dir.repeat(3, axis=-1)
                        images = brightness_to_opacity(
                            conv2d(images, filter_=norm_filter(15))
                        )
                        suffix = f"{direction}_{channel}"
                        images_filename = f"{filename_root}_{suffix}.png"
                        images_path = os.path.join(attribution_subdir, images_filename)
                        save_images(images, images_path)
                        # The scrub image is the slice of axis 2 chosen by
                        # get_scrub_slice.
                        scrub = images[:, :, get_scrub_slice(images.shape[2]), :]
                        scrub_path = os.path.join(
                            attribution_scrub_subdir, images_filename
                        )
                        save_images(scrub, scrub_path)
                        attribution["images"].append(images_filename)
                        attribution["metadata"]["type"].append(type_)
                        attribution["metadata"]["data"].append(data)
                        attribution["metadata"]["direction"].append(direction)
                        attribution["metadata"]["channel"].append(channel)