in run_exp.py [0:0]
def get_batch(free_queue: mp.SimpleQueue,
              full_queue: mp.SimpleQueue,
              buffers: Buffers,
              flags,
              timings,
              lock=threading.Lock()) -> typing.Dict[str, torch.Tensor]:
    """Assemble one learner batch from filled buffer slots.

    Reads ``flags.batch_size`` slot indices from ``full_queue``. For every
    key in ``buffers``, it stacks the selected slots along dimension 1. It
    then puts the indices back on ``free_queue`` and moves each stacked
    tensor to ``flags.device``.

    Args:
        free_queue: Queue that receives the slot indices once their data has
            been copied into the batch.
        full_queue: Queue the slot indices are taken from.
        buffers: Mapping from key to an indexable container where
            ``buffers[key][m]`` is the tensor stored in slot ``m``.
        flags: Config object; ``batch_size`` and ``device`` are read.
        timings: Object whose ``time(name)`` is called after each stage
            ('lock', 'dequeue', 'batch', 'enqueue', 'device').
        lock: Guards the dequeue step. The default is created once, when the
            function is defined, so every call that does not pass its own
            lock shares the same one (presumably intended, so concurrent
            learner threads do not interleave their index reads).

    Returns:
        Dict mapping each key of ``buffers`` to its stacked tensor on
        ``flags.device``.
    """
    with lock:
        timings.time('lock')
        slot_ids = []
        for _ in range(flags.batch_size):
            slot_ids.append(full_queue.get())
        timings.time('dequeue')

    # torch.stack always allocates a new tensor, so the slots are safe to
    # hand back for reuse once this dict is built.
    stacked = {}
    for key in buffers:
        per_slot = buffers[key]
        stacked[key] = torch.stack([per_slot[i] for i in slot_ids], dim=1)
    timings.time('batch')

    for slot in slot_ids:
        free_queue.put(slot)
    timings.time('enqueue')

    # non_blocking only overlaps the copy when the source tensor is in
    # pinned memory; whether these buffers are pinned is not visible here.
    # NOTE(review): confirm.
    on_device = {}
    for key, tensor in stacked.items():
        on_device[key] = tensor.to(device=flags.device, non_blocking=True)
    timings.time('device')
    return on_device