in monobeast/minigrid/monobeast_amigo.py [0:0]
def act(
actor_index: int,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
model: torch.nn.Module,
generator_model,
buffers: Buffers,
initial_agent_state_buffers, flags):
"""Defines and generates IMPALA actors in multiples threads."""
try:
logging.info("Actor %i started.", actor_index)
timings = prof.Timings() # Keep track of how fast things are.
gym_env = create_env(flags)
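        # XOR the actor index with fresh OS entropy so each actor gets a
        # distinct, unpredictable environment seed.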
seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
gym_env.seed(seed)
#gym_env = wrappers.FullyObsWrapper(gym_env)
if flags.num_input_frames > 1:
gym_env = FrameStack(gym_env, flags.num_input_frames)
env = Observation_WrapperSetup(gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed)
env_output = env.initial()
initial_frame = env_output['frame']
agent_state = model.initial_state(batch_size=1)
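        # The generator (AMIGo's teacher) proposes an intrinsic goal; the
        # student policy is conditioned on that goal at every step.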
generator_output = generator_model(env_output)
goal = generator_output["goal"]
agent_output, unused_state = model(env_output, agent_state, goal)
while True:
index = free_queue.get()
if index is None:
break
# Write old rollout end.
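            # Slot 0 of each buffer carries over the last step of the previous
            # rollout so the learner can bootstrap across rollout boundaries.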
for key in env_output:
buffers[key][index][0, ...] = env_output[key]
for key in agent_output:
buffers[key][index][0, ...] = agent_output[key]
for key in generator_output:
buffers[key][index][0, ...] = generator_output[key]
buffers["initial_frame"][index][0, ...] = initial_frame
for i, tensor in enumerate(agent_state):
initial_agent_state_buffers[index][i][...] = tensor
            # Do a new rollout of flags.unroll_length steps.
for t in range(flags.unroll_length):
                aux_steps = 0  # Unused in this function.
timings.reset()
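                # Decide whether the current intrinsic goal has been reached.
                # With flags.modify, the goal counts as reached when the goal
                # cell's encoding differs from the initial frame (e.g. a door
                # was opened); otherwise it is reached only when the agent
                # itself occupies the goal cell.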
if flags.modify:
new_frame = torch.flatten(env_output['frame'], 2, 3)
old_frame = torch.flatten(initial_frame, 2, 3)
ans = new_frame == old_frame
                    ans = torch.sum(ans, 3) != 3  # Cell changed if its three channels no longer all match the initial frame.
                    reached_condition = torch.squeeze(torch.gather(ans, 2, torch.unsqueeze(goal.long(), 2)))
else:
                    agent_location = torch.flatten(env_output['frame'], 2, 3)
                    agent_location = agent_location[:, :, :, 0]  # Channel 0 holds the object id.
                    agent_location = (agent_location == 10).nonzero()  # 10 = agent in MiniGrid.
                    agent_location = agent_location[:, 2]
                    agent_location = agent_location.view(agent_output["action"].shape)
                    reached_condition = goal == agent_location
                if reached_condition:  # Generate a new goal when the intrinsic goal has been reached.
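                    # Either restart the episode from scratch or just reset the
                    # step counter so the episode continues under a new goal.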
if flags.restart_episode:
env_output = env.initial()
else:
env.episode_step = 0
initial_frame = env_output['frame']
with torch.no_grad():
generator_output = generator_model(env_output)
goal = generator_output["goal"]
                if env_output['done'][0] == 1:  # Generate a new goal when the episode finishes.
initial_frame = env_output['frame']
with torch.no_grad():
generator_output = generator_model(env_output)
goal = generator_output["goal"]
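                # Act with the student policy; the actor never backpropagates,
                # so inference runs under torch.no_grad().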
with torch.no_grad():
agent_output, agent_state = model(env_output, agent_state, goal)
timings.time("model")
env_output = env.step(agent_output["action"])
timings.time("step")
for key in env_output:
buffers[key][index][t + 1, ...] = env_output[key]
for key in agent_output:
buffers[key][index][t + 1, ...] = agent_output[key]
for key in generator_output:
buffers[key][index][t + 1, ...] = generator_output[key]
buffers["initial_frame"][index][t + 1, ...] = initial_frame
timings.time("write")
full_queue.put(index)
if actor_index == 0:
logging.info("Actor %i: %s", actor_index, timings.summary())
except KeyboardInterrupt:
pass # Return silently.
    except Exception:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise
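
# Usage sketch (illustrative, not part of this file): monobeast-style training
# code typically spawns one `act` process per actor. Names such as `train` and
# `flags.num_actors`, and the exact argument order below, are assumptions based
# on the call signature above; check the actual train() in this file.
#
#     ctx = mp.get_context("fork")
#     free_queue, full_queue = ctx.SimpleQueue(), ctx.SimpleQueue()
#     for i in range(flags.num_actors):
#         actor = ctx.Process(
#             target=act,
#             args=(i, free_queue, full_queue, model, generator_model,
#                   buffers, initial_agent_state_buffers, flags),
#         )
#         actor.start()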