in understanding_rl_vision/rl_clarity/interface.py [0:0]
def save_attr(attr, attr_res, *, type_, data):
    """Save attribution totals and rendered attribution images for one type/data pair.

    This is a closure: it reads ``layer_key``, ``index``, ``layer_name``,
    ``attr_single_channels``, ``nmf``, the ``attribution_subdir`` /
    ``attribution_scrub_subdir`` / ``attribution_totals_subdir`` paths, and
    appends to the ``totals`` and ``attribution`` manifests from the
    enclosing scope.

    Args:
        attr: attribution array. The last axis is treated as the channel
            axis (it is indexed per channel below), and the two axes before
            it are summed to produce per-channel totals.
            NOTE(review): presumably (batch, height, width, channels) — confirm.
        attr_res: residual attribution with a trailing singleton axis
            (``attr_res[..., 0]`` is read), or None. When None, zeros shaped
            like ``attr`` summed over its last axis (keepdims) are used.
        type_: attribution type label; used in filenames and metadata.
        data: optional extra label; appended to filenames when not None and
            recorded in metadata.
    """
    if attr_res is None:
        attr_res = np.zeros_like(attr).sum(-1, keepdims=True)

    filename_root = f"{squash(layer_key)}_{index}_{type_}"
    if data is not None:
        filename_root = f"{filename_root}_{data}"

    if layer_name is not None:
        # Totals per channel / for the residual, summed over the two axes
        # that precede the channel axis.
        channels_filename = f"{filename_root}_channels.json"
        residuals_filename = f"{filename_root}_residuals.json"
        save_json(
            attr.sum(-2).sum(-2),
            os.path.join(attribution_totals_subdir, channels_filename),
        )
        save_json(
            attr_res[..., 0].sum(-1).sum(-1),
            os.path.join(attribution_totals_subdir, residuals_filename),
        )
        totals["channels"].append(channels_filename)
        totals["residuals"].append(residuals_filename)
        totals["metadata"]["type"].append(type_)
        totals["metadata"]["data"].append(data)

    # Normalize by the median of the per-item maxima over the last three axes.
    # The scale must be strictly positive: a negative scale (possible when
    # most items contain only negative attribution) would flip the sign of
    # every value and swap the "pos" and "neg" renderings below.
    attr_scale = np.median(attr.max(axis=(-3, -2, -1)))
    if attr_scale <= 0:
        attr_scale = attr.max()
    if attr_scale <= 0:
        attr_scale = 1
    attr_scaled = attr / attr_scale
    attr_res_scaled = attr_res / attr_scale

    channels = ["prin", "all"]
    if attr_single_channels and layer_name is not None:
        # Additionally render each NMF feature on its own, plus the residual.
        channels += list(range(nmf.features)) + ["res"]

    for direction in ["abs", "pos", "neg"]:
        # Use dedicated locals rather than rebinding the `attr`/`attr_res`
        # parameters, so the scaled inputs stay clearly separate.
        if direction == "abs":
            attr_dir = np.abs(attr_scaled)
            attr_res_dir = np.abs(attr_res_scaled)
        elif direction == "pos":
            attr_dir = np.maximum(attr_scaled, 0)
            attr_res_dir = np.maximum(attr_res_scaled, 0)
        else:  # "neg": magnitude of the negative part
            attr_dir = np.maximum(-attr_scaled, 0)
            attr_res_dir = np.maximum(-attr_res_scaled, 0)

        for channel in channels:
            if isinstance(channel, int):
                # Keep only this channel; zero out every other channel.
                attr_single = attr_dir.copy()
                attr_single[..., :channel] = 0
                attr_single[..., (channel + 1) :] = 0
                images = channels_to_rgb(attr_single)
            elif channel == "res":
                # Residual is single-channel; replicate it to 3 colour channels.
                images = attr_res_dir.repeat(3, axis=-1)
            else:
                images = channels_to_rgb(attr_dir)
                if channel == "all":
                    images += attr_res_dir.repeat(3, axis=-1)
            images = brightness_to_opacity(
                conv2d(images, filter_=norm_filter(15))
            )

            suffix = f"{direction}_{channel}"
            images_filename = f"{filename_root}_{suffix}.png"
            images_path = os.path.join(attribution_subdir, images_filename)
            save_images(images, images_path)

            # A slice along axis 2 of the rendered images, saved under the
            # same filename in the scrub subdirectory.
            scrub = images[:, :, get_scrub_slice(images.shape[2]), :]
            scrub_path = os.path.join(attribution_scrub_subdir, images_filename)
            save_images(scrub, scrub_path)

            attribution["images"].append(images_filename)
            attribution["metadata"]["type"].append(type_)
            attribution["metadata"]["data"].append(data)
            attribution["metadata"]["direction"].append(direction)
            attribution["metadata"]["channel"].append(channel)