in monobeast/minigrid/monobeast_amigo.py [0:0]
def get_batch(
    flags,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: Buffers,
    initial_agent_state_buffers,
    timings,
    lock=threading.Lock(),
):
    """Return a batch with the history, plus the matching initial agent state.

    Pulls ``flags.batch_size`` slot indices from ``full_queue``, stacks the
    corresponding entries of every buffer along dimension 1, concatenates the
    per-slot initial agent state tensors along dimension 1, hands the indices
    back through ``free_queue`` and finally moves everything to
    ``flags.device``.

    Args:
        flags: Object exposing ``batch_size`` and ``device``.
        free_queue: Queue that receives each index once its data was copied.
        full_queue: Queue from which the indices to batch are taken.
        buffers: Mapping from key to an indexable collection of tensors.
        initial_agent_state_buffers: Indexable collection; each entry is an
            iterable of tensors (one per agent-state component).
        timings: Object whose ``time(name)`` is called after each stage
            ("lock", "dequeue", "batch", "enqueue", "device").
        lock: Lock held while dequeuing. The default is created once at
            definition time, so all callers relying on the default share
            the same lock.

    Returns:
        A ``(batch, initial_agent_state)`` pair: ``batch`` is a dict with the
        same keys as ``buffers`` holding the stacked tensors, and
        ``initial_agent_state`` is a tuple of concatenated tensors, all on
        ``flags.device``.
    """
    # NOTE(review): the source this was recovered from had lost its
    # indentation; the lock is assumed to guard only the dequeue so that one
    # caller draws all of its indices without interleaving with another.
    with lock:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time("dequeue")

    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }
    # Build the state eagerly (tuple, not a lazy generator): torch.cat copies
    # the data out of the buffers, and that copy has to happen *before* the
    # indices are released below. A lazy generator would postpone the copy
    # until the device transfer, by which time whoever picked a released
    # index off free_queue may already have overwritten the slot in place.
    initial_agent_state = tuple(
        torch.cat(ts, dim=1)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    timings.time("batch")

    # Everything needed from these slots has been copied; release them.
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")

    batch = {
        k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()
    }
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
    )
    timings.time("device")
    return batch, initial_agent_state