in lucid/scratch/rl_util/attribution.py [0:0]
def get_paths(acts, nmf, *, max_paths, integrate_steps):
    """Yield interpolation paths from (roughly) zero up to ``acts`` in NMF space.

    The NMF features are split into two halves: a subset of size
    ``nmf.features // 2`` and its complement. For each split, up to two paths
    are produced. Along one path the features in the subset ramp up as
    ``alpha**2`` (slowly at first) while the remaining features ramp up as
    ``1 - (1 - alpha)**2`` (quickly at first). The other ("complement") path
    swaps the two roles. The part of ``acts`` that the NMF reconstruction
    does not capture (the residual) is scaled linearly by ``alpha``. At
    ``alpha == 1`` every path therefore equals
    ``nmf.inverse_transform(nmf.transform(acts)) + residual``, which is
    ``acts``.

    Args:
        acts: Activations. They are passed to ``nmf.transform`` and combined
            elementwise with ``nmf.inverse_transform``'s output, so they are
            presumably an array of matching shape (TODO confirm with callers).
        nmf: Object exposing ``transform``, ``inverse_transform`` and an
            integer ``features`` (the number of NMF components). The per-feature
            mask is aligned with the last axis of ``nmf.transform(acts)``.
        max_paths: ``None`` to yield both paths for every split. Otherwise,
            the maximum number of paths to yield. Splits are then sampled
            uniformly without replacement using the global ``np.random``
            state. When the budget is odd, the last split contributes a
            single, randomly chosen path.
        integrate_steps: Number of points per path. The points are taken at
            ``alpha`` in ``np.linspace(0, 1, integrate_steps + 1)[1:]``, so
            ``alpha == 0`` is excluded and ``alpha == 1`` is included.

    Yields:
        A list of ``integrate_steps`` arrays, one per ``alpha``, each equal to
        ``nmf.inverse_transform(acts_reduced * coordinates) + residual * alpha``.
    """
    acts_reduced = nmf.transform(acts)
    # Whatever the NMF reconstruction misses; scaled linearly along each path
    # so that every path ends exactly at ``acts``.
    residual = acts - nmf.inverse_transform(acts_reduced)

    half = nmf.features // 2
    # NOTE: this materialises all C(features, features // 2) subsets, which
    # grows quickly with the number of features.
    combs = itertools.combinations(range(nmf.features), half)
    if nmf.features % 2 == 0:
        # With an even feature count, a half-size subset's complement is also
        # half-size. Keep only the subsets containing feature 0 so that each
        # unordered {subset, complement} pair appears once.
        combs = [comb for comb in combs if 0 in comb]
    else:
        combs = list(combs)
    # Force an integer, 2-D array. With a single feature the only subset is
    # the empty tuple, and np.array([()]) would be float64, which numpy
    # refuses to use as an index below.
    combs = np.array(combs, dtype=int).reshape(len(combs), half)

    if max_paths is None:
        splits = combs
    else:
        # Each split yields up to two paths (the subset and its complement).
        num_splits = min((max_paths + 1) // 2, combs.shape[0])
        splits = combs[
            np.random.choice(combs.shape[0], size=num_splits, replace=False), :
        ]

    # Loop invariants: the interpolation points, and the leading singleton
    # axes that let the per-feature mask broadcast against ``acts_reduced``.
    alphas = np.linspace(0, 1, integrate_steps + 1)[1:]
    broadcast = tuple(None for _ in range(acts_reduced.ndim - 1))

    for i, split in enumerate(splits):
        indices = np.zeros(nmf.features)
        indices[split] = 1.0
        indices = indices[broadcast]
        complements = [False, True]
        if max_paths is not None and i * 2 + 1 == max_paths:
            # Odd budget: this last split only gets one of its two paths.
            complements = [np.random.choice(complements)]
        for complement in complements:
            # ``late`` features follow alpha**2 and ``early`` features follow
            # 1 - (1 - alpha)**2. The mask is exactly 0.0/1.0, so
            # ``1.0 - late`` is exact.
            late = 1.0 - indices if complement else indices
            early = 1.0 - late
            path = []
            for alpha in alphas:
                coordinates = late * alpha ** 2 + early * (1.0 - (1.0 - alpha) ** 2)
                path.append(
                    nmf.inverse_transform(acts_reduced * coordinates)
                    + residual * alpha
                )
            yield path