in sparse_autoencoder/explanations.py [0:0]
def batch_parallelize(algos, fn, batch_size):
    """
    Drive several generator-based algorithms in lock-step, batching their requests.

    Each algorithm is a generator that repeatedly yields an input it needs
    processed, receives ``fn``'s result for that input via ``send``, and
    eventually ``return``s its final answer. On every round, the pending
    inputs of all still-running algorithms are gathered and handed together
    to ``apply_batched(fn, inputs, batch_size)``, so requests from different
    algorithms share batches. The algorithms are interleaved round by round
    in the calling thread; they are not run in separate threads.

    Args:
        algos: Generators to drive. Each is advanced first with ``next`` and
            afterwards with ``send``. Any iterable of generators is accepted;
            it is materialized into a list.
        fn: Function applied (via ``apply_batched``) to the yielded inputs.
        batch_size: Batch size forwarded to ``apply_batched``.

    Returns:
        A list with one entry per algorithm, in the same order as ``algos``,
        holding each generator's return value (``None`` for a bare
        ``return`` or falling off the end).

    Raises:
        RuntimeError: if ``apply_batched`` produces a different number of
            results than the number of inputs it was given.
    """
    algos = list(algos)
    results = [None] * len(algos)

    # Pending work: (algorithm index, input that algorithm is waiting on).
    pending = []
    for i, algo in enumerate(algos):
        try:
            first_input = next(algo)
        except StopIteration as e:
            # The algorithm finished without requesting any work; record its
            # return value instead of letting StopIteration escape.
            results[i] = e.value
        else:
            pending.append((i, first_input))

    while pending:
        indices = [i for i, _ in pending]
        outputs = list(apply_batched(fn, [x for _, x in pending], batch_size))
        # Results are matched back to algorithms positionally, so a length
        # mismatch would silently misroute or drop results; fail loudly
        # (an `assert` would be stripped under `python -O`).
        if len(outputs) != len(pending):
            raise RuntimeError(
                f"apply_batched returned {len(outputs)} results "
                f"for {len(pending)} inputs"
            )
        pending = []
        for i, out in zip(indices, outputs):
            try:
                next_input = algos[i].send(out)
            except StopIteration as e:
                # Algorithm is done; its return value is the final result.
                results[i] = e.value
            else:
                pending.append((i, next_input))
    return results