# batch_parallelize() — excerpted from sparse_autoencoder/explanations.py



def batch_parallelize(algos, fn, batch_size):
    """
    Run several coroutine-style algorithms concurrently, batching their work.

    Each element of ``algos`` is a generator that yields items to be processed
    by ``fn`` and receives the corresponding processed result back via
    ``send``. On each round, the pending items from all still-running
    algorithms are processed together through ``apply_batched`` so the
    expensive ``fn`` work is batched across algorithms. When a generator
    finishes (raises ``StopIteration``), its return value is captured as that
    algorithm's result.

    :param algos: sequence of (not-yet-started) generators; each ``yield``
        produces one work item and expects the processed result sent back.
    :param fn: function applied to the yielded items via ``apply_batched``.
    :param batch_size: maximum number of items per ``apply_batched`` call.
    :return: list of each algorithm's return value, ordered like ``algos``.
    """
    results = [None] * len(algos)
    # (algorithm index, pending work item) pairs for the current round.
    inputs = []
    for i, algo in enumerate(algos):
        try:
            inputs.append((i, next(algo)))
        except StopIteration as e:
            # The coroutine returned without yielding any work; record its
            # result now instead of crashing the whole batch (the send loop
            # below already handles this case, the priming loop must too).
            results[i] = e.value
    while inputs:
        ret = list(apply_batched(fn, [item for _, item in inputs], batch_size))
        assert len(ret) == len(inputs)
        inds = [i for i, _ in inputs]
        inputs = []
        for i, r in zip(inds, ret):
            try:
                # Feed the result back in; the coroutine either yields its
                # next work item or finishes with a return value.
                inputs.append((i, algos[i].send(r)))
            except StopIteration as e:
                results[i] = e.value
    return results