in src/gluonts/model/predictor.py [0:0]
def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
    with TemporaryDirectory() as tempdir:
        predictor_path = Path(tempdir)
        self._base_predictor.serialize(predictor_path)
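        # Each worker loads its own copy of the predictor from this path,
        # so the predictor object itself never has to cross the process
        # boundary.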

        # TODO: Consider using shared memory for the data transfer.

        self._input_queues = [mp.Queue() for _ in range(self._num_workers)]
        self._output_queue = mp.Queue()
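
        # Spawn one daemon worker process per input queue; all workers
        # write their results to the shared output queue.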
        workers = []
        for worker_id, in_q in enumerate(self._input_queues):
            worker = mp.Process(
                target=_worker_loop,
                args=(predictor_path, in_q, self._output_queue, worker_id),
                kwargs=kwargs,
            )
            worker.daemon = True
            worker.start()
            workers.append(worker)
            self._num_running_workers += 1

        self._workers = workers
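
        # Group the dataset into chunks of self._chunk_size entries;
        # chunks are handed to the workers one at a time.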
        chunked_data = self._grouper(dataset, self._chunk_size)
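
        # _send_idx counts chunks handed to workers, _next_idx is the
        # next chunk index to yield; _data_buffer holds result batches
        # that arrived out of order.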
        self._send_idx = 0
        self._next_idx = 0
        self._data_buffer = {}

        worker_ids = list(range(self._num_workers))
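
        # Fetch one item from the output queue, buffering real results by
        # chunk index and turning worker failures into an exception.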
        def receive():
            idx, worker_id, result = self._output_queue.get()
            if isinstance(idx, WorkerError):
                self._num_running_workers -= 1
                self.terminate()
                raise Exception(idx.msg)
            if idx is not None:
                self._data_buffer[idx] = result
            return idx, worker_id, result
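
        # Yield buffered result batches in their original submission
        # order, stopping at the first missing chunk index.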
        def get_next_from_buffer():
            while self._next_idx in self._data_buffer:
                result_batch = self._data_buffer.pop(self._next_idx)
                self._next_idx += 1
                for result in result_batch:
                    yield result
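
        # Hand a chunk to the chosen worker, tagged with the current send
        # index so its results can be re-ordered later.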
        def send(worker_id, chunk):
            q = self._input_queues[worker_id]
            q.put((self._send_idx, chunk))
            self._send_idx += 1
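
        # Keep exactly one chunk in flight per worker: whenever a worker
        # reports back, yield any in-order results and hand it the next
        # chunk. next(chunked_data) raises StopIteration once the dataset
        # is exhausted, which triggers the shutdown path below.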
        try:
            # prime the queues
            for wid in worker_ids:
                chunk = next(chunked_data)
                send(wid, chunk)

            while True:
                idx, wid, result = receive()
                for res in get_next_from_buffer():
                    yield res
                chunk = next(chunked_data)
                send(wid, chunk)
        except StopIteration:
            # signal workers end of data
            for q in self._input_queues:
                q.put((None, None))

            # collect any outstanding results
            while self._num_running_workers > 0:
                idx, worker_id, result = receive()
                if idx is None:
                    self._num_running_workers -= 1
                    continue
                for res in get_next_from_buffer():
                    yield res
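
            # By now every chunk that was sent out has been received and
            # yielded back in order.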
            assert len(self._data_buffer) == 0
            assert self._send_idx == self._next_idx
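
The method relies on a module-level _worker_loop helper that is not shown in this excerpt. The following is only a minimal sketch of what such a loop could look like, inferred from the queue protocol above: (idx, chunk) tuples in, (idx, worker_id, results) tuples out, a (None, None) sentinel to stop, and a WorkerError in place of the index on failure. Predictor.deserialize and the WorkerError(message) constructor are assumptions, not confirmed by this excerpt.

def _worker_loop(predictor_path, input_queue, output_queue, worker_id, **kwargs):
    """Hypothetical worker body; names and signatures are assumptions."""
    # Restore this worker's own copy of the serialized predictor
    # (Predictor.deserialize is assumed to be importable in this module).
    predictor = Predictor.deserialize(predictor_path)
    while True:
        idx, chunk = input_queue.get()
        if idx is None:
            # End-of-data sentinel from predict(): acknowledge and exit,
            # letting the parent decrement its running-worker count.
            output_queue.put((None, worker_id, None))
            break
        try:
            # Materialize the forecasts so the whole batch travels back at once.
            results = list(predictor.predict(chunk, **kwargs))
            output_queue.put((idx, worker_id, results))
        except Exception as error:
            # Report the failure in place of a chunk index; the parent's
            # receive() checks for WorkerError and re-raises.
            # (Assumes WorkerError can be built from a message string.)
            output_queue.put((WorkerError(str(error)), worker_id, None))
            break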