def get_paths()

in lucid/scratch/rl_util/attribution.py [0:0]


def get_paths(acts, nmf, *, max_paths, integrate_steps):
    """Yield integration paths that end at ``acts``, one per factor grouping.

    ``acts`` is split into the part reconstructed by ``nmf``
    (``nmf.inverse_transform(nmf.transform(acts))``) and a residual. The
    ``nmf.features`` factors are divided into two groups; along a path one
    group is scaled by ``alpha ** 2`` and the other by
    ``1 - (1 - alpha) ** 2`` while the residual is scaled linearly by
    ``alpha``. Every grouping is visited together with its complement (the
    two groups swapped), unless ``max_paths`` cuts the enumeration short.

    Args:
        acts: array of activations; its last axis must be compatible with
            what ``nmf.transform`` expects.
        nmf: object exposing an integer ``features`` attribute and
            ``transform`` / ``inverse_transform`` methods. The last axis of
            ``transform``'s output is assumed to have length ``features``
            (required for the broadcasting below).
        max_paths: ``None`` to enumerate every grouping, otherwise an upper
            bound on the number of yielded paths. Groupings are then sampled
            without replacement via ``np.random.choice``.
        integrate_steps: number of points per path. ``alpha`` runs over
            ``integrate_steps`` evenly spaced values in ``(0, 1]``; the
            ``alpha == 0`` starting point is not included.

    Yields:
        A list of ``integrate_steps`` arrays shaped like ``acts``; the last
        entry corresponds to ``alpha == 1``.
    """
    acts_reduced = nmf.transform(acts)
    residual = acts - nmf.inverse_transform(acts_reduced)

    candidates = itertools.combinations(range(nmf.features), nmf.features // 2)
    if nmf.features % 2 == 0:
        # With an even factor count a group and its complement have the same
        # size, so each partition would show up twice; keeping only the
        # groups that contain factor 0 removes the duplicates.
        candidates = [comb for comb in candidates if 0 in comb]
    # dtype=int matters: when the combinations are empty tuples (a single
    # factor) NumPy would otherwise infer float64, and float arrays cannot be
    # used as indices further down.
    combs = np.array(list(candidates), dtype=int)

    if max_paths is None:
        splits = combs
    else:
        # Each split contributes up to two paths (itself and its complement).
        num_splits = min((max_paths + 1) // 2, combs.shape[0])
        splits = combs[
            np.random.choice(combs.shape[0], size=num_splits, replace=False), :
        ]

    # Loop-invariant: the alpha grid, without the alpha == 0 starting point.
    alphas = np.linspace(0, 1, integrate_steps + 1)[1:]
    # Leading singleton axes so the per-factor mask broadcasts against
    # acts_reduced along its last axis.
    lead_axes = tuple(None for _ in range(acts_reduced.ndim - 1))

    for i, split in enumerate(splits):
        indices = np.zeros(nmf.features)
        indices[split] = 1.0
        indices = indices[lead_axes]

        complements = [False, True]
        if max_paths is not None and i * 2 + 1 == max_paths:
            # Odd budget: only one path is left, so pick one orientation of
            # this split at random.
            complements = [np.random.choice(complements)]

        for complement in complements:
            # `slow_group` marks the factors that follow alpha ** 2; the
            # remaining factors follow 1 - (1 - alpha) ** 2.
            slow_group = 1.0 - indices if complement else indices
            fast_group = 1.0 - slow_group
            path = []
            for alpha in alphas:
                coordinates = slow_group * alpha ** 2 + fast_group * (
                    1.0 - (1.0 - alpha) ** 2
                )
                path.append(
                    nmf.inverse_transform(acts_reduced * coordinates) + residual * alpha
                )
            yield path