in lucid/scratch/rl_util/nmf.py [0:0]
def argmax_nd(x, axes, *, max_rep=np.inf, max_rep_strict=None):
    """Return the argmax of ``x`` over several axes, optionally capping repeats.

    The axes listed in ``axes`` are flattened into a single "choice"
    dimension.  For every position along the remaining axes, the choice with
    the largest value is selected.  If ``max_rep`` is finite, choices that are
    selected more than ``max_rep`` times are greedily reassigned, weakest
    positions first, to lower-ranked choices for those positions.

    Args:
        x: array_like of values to maximise over.
        axes: sequence of axis indices to take the argmax over; negative
            values count from the end.  The returned index arrays follow the
            order given here.
        max_rep: maximum number of positions any single choice may be
            assigned to (must be > 0).  ``np.inf`` (default) gives a plain
            argmax.
        max_rep_strict: required when ``max_rep`` is finite.  If truthy, only
            alternative (rank >= 1) choices are tried.  If falsy, one extra
            pass wraps around to rank 0, i.e. positions that still exceed the
            cap are sent back to their best choice.

    Returns:
        A tuple of ``len(axes)`` integer arrays (from ``np.unravel_index``),
        each shaped like the non-reduced axes of ``x`` in their original
        order.

    Raises:
        ValueError: if an axis is out of bounds or repeated.
    """
    assert max_rep > 0
    assert np.isinf(max_rep) or max_rep_strict is not None
    x = np.asarray(x)
    ndim = x.ndim

    # Normalise negative axes and reject invalid or repeated ones; repeats
    # would otherwise silently fold an unrelated axis into the choice dim.
    norm_axes = []
    for axis in axes:
        if not -ndim <= axis < ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for array of dimension {ndim}"
            )
        norm_axes.append(axis % ndim)
    if len(set(norm_axes)) != len(norm_axes):
        raise ValueError(f"repeated axis in axes: {tuple(axes)}")
    n_axes = len(norm_axes)

    # Bring the reduction axes to the front (in the given order), keeping
    # the remaining axes in their original relative order.
    rest = [axis for axis in range(ndim) if axis not in norm_axes]
    x = x.transpose(norm_axes + rest)
    shape = x.shape
    axes_size = int(np.prod(shape[:n_axes]))

    # One row per flattened choice, one column per remaining position.
    x = x.reshape([axes_size, -1])

    # Rank choices per column from best to worst.  ``-x`` overflows for
    # unsigned ints and for the minimum signed int, and is a TypeError for
    # bools; ``~x`` reverses integer/bool order without overflow.
    descending_key = ~x if x.dtype.kind in "biu" else -x
    indices = np.argsort(descending_key, axis=0)
    n_choices = len(indices)

    result = indices[0].copy()
    # counts[c] == number of positions currently assigned to choice c.
    counts = np.bincount(result, minlength=n_choices)
    columns = np.arange(result.shape[0])

    # Strict mode tries ranks 1..n-1 only; non-strict adds a final pass at
    # rank n, which wraps (via the modulo below) to each position's best.
    stop = n_choices + (0 if max_rep_strict else 1)
    for rank in range(1, stop):
        # Visit positions from weakest to strongest current value so the
        # weakest holders of an over-used choice are the ones moved.
        order = np.argsort(x[result, columns])
        current_counts = counts.copy()
        changed = False
        for pos in order:
            value = result[pos]
            if current_counts[value] > max_rep:
                new_value = indices[rank % n_choices, pos]
                result[pos] = new_value
                # current_counts tracks how many holders of `value` remain to
                # be visited this pass; counts is the true global tally.
                current_counts[value] -= 1
                counts[value] -= 1
                counts[new_value] += 1
                changed = True
        if not changed:
            break

    result = result.reshape(shape[n_axes:])
    return np.unravel_index(result, shape[:n_axes])