in run_exp.py [0:0]
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers, flags):
    """Actor loop: steps an environment with `model` and fills rollout buffers.

    Runs in its own worker process. Repeatedly pulls a free buffer index from
    `free_queue`, writes a rollout of `flags.unroll_length + 1` steps into
    `buffers[...][index]` (step 0 is the carry-over from the previous rollout,
    so consecutive rollouts line up for bootstrapping), then hands the filled
    index to the learner via `full_queue`. A `None` on `free_queue` is the
    shutdown signal.

    Args:
        i: Actor index; actor 0 additionally logs timing summaries.
        free_queue: Queue of buffer indices available for writing.
        full_queue: Queue of buffer indices ready for the learner.
        model: Policy network; called on env output, must yield a dict
            containing at least an 'action' entry.
        buffers: Shared rollout storage, keyed like the model/env outputs.
        flags: Run configuration; `unroll_length` is read here.

    Raises:
        Re-raises any exception (other than KeyboardInterrupt) after
        logging it, so the parent process can observe the failure.
    """
    try:
        logging.info('Actor %i started.', i)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = Net.create_env(flags)
        # Per-actor seed: XOR the actor index with fresh OS entropy so
        # actors are decorrelated even across identical restarts.
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        env = environment.Environment(gym_env)

        env_output = env.initial()
        # Fix: the actor never backpropagates, so run the initial forward
        # pass under no_grad as well (the in-loop pass already does) to
        # avoid building a needless autograd graph.
        with torch.no_grad():
            agent_output = model(env_output)

        while True:
            index = free_queue.get()
            if index is None:  # Shutdown signal from the main process.
                break

            # Write old rollout end: step 0 of this rollout is the final
            # state/output of the previous one.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output = model(env_output)
                timings.time('model')

                env_output = env.step(agent_output['action'])
                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]
                timings.time('write')

            full_queue.put(index)

            if i == 0:
                logging.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise  # Bare raise preserves the original traceback.