def act()

in monobeast/minigrid/monobeast_amigo.py [0:0]


def act(
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    generator_model,
    buffers: Buffers,
    initial_agent_state_buffers, flags):
    """Run a single actor: roll out the policy and fill rollout buffers.

    The actor builds its own environment, then loops: it takes a buffer
    index from ``free_queue``, writes ``flags.unroll_length + 1`` time steps
    of environment, agent and goal-generator outputs into ``buffers`` at
    that index, and hands the index back through ``full_queue``. Receiving
    ``None`` from ``free_queue`` stops the loop.

    A new goal is requested from ``generator_model`` whenever the current
    goal is judged reached or the environment reports ``done``.

    Args:
        actor_index: Identifier of this actor; used for logging and mixed
            into the environment seed.
        free_queue: Queue of buffer indices that may be overwritten.
        full_queue: Queue onto which filled buffer indices are put.
        model: Policy network, called as ``model(env_output, state, goal)``.
        generator_model: Goal generator, called as
            ``generator_model(env_output)``; its output must contain a
            ``"goal"`` entry.
        buffers: Mapping from key to per-index rollout tensors.
        initial_agent_state_buffers: Per-index storage for the agent state
            at the start of each rollout.
        flags: Run configuration. Fields read here: ``num_input_frames``,
            ``fix_seed``, ``env_seed``, ``unroll_length``, ``modify`` and
            ``restart_episode`` (``create_env`` may read more).

    Raises:
        Exception: Any non-KeyboardInterrupt error is logged, its traceback
            printed, and then re-raised.
    """

    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.
        gym_env = create_env(flags)
        # Mix OS entropy into the seed so actors do not share a seed.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        #gym_env = wrappers.FullyObsWrapper(gym_env)

        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)

        env = Observation_WrapperSetup(
            gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed)
        env_output = env.initial()
        # Frame the current goal was generated from; the `flags.modify`
        # reach test compares later frames against it.
        initial_frame = env_output['frame']

        agent_state = model.initial_state(batch_size=1)
        # The actor only performs inference, so the first forward passes are
        # run without gradient tracking, like every later call in the loop.
        # This keeps autograd history out of the rollout buffers.
        with torch.no_grad():
            generator_output = generator_model(env_output)
            goal = generator_output["goal"]
            agent_output, unused_state = model(env_output, agent_state, goal)

        while True:
            index = free_queue.get()
            if index is None:
                # Sentinel: stop this actor.
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for key in generator_output:
                buffers[key][index][0, ...] = generator_output[key]
            buffers["initial_frame"][index][0, ...] = initial_frame
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                if flags.modify:
                    # Goal is reached when the goal cell differs from the
                    # frame the goal was generated from.
                    new_frame = torch.flatten(env_output['frame'], 2, 3)
                    old_frame = torch.flatten(initial_frame, 2, 3)
                    unchanged = new_frame == old_frame
                    # A cell counts as changed unless all 3 of its values
                    # match (assumes 3 values per cell -- TODO confirm).
                    changed = torch.sum(unchanged, 3) != 3
                    reached_condition = torch.squeeze(
                        torch.gather(
                            changed, 2, torch.unsqueeze(goal.long(), 2)))
                else:
                    # Goal is reached when the agent's flattened cell index
                    # equals the goal.
                    agent_location = torch.flatten(env_output['frame'], 2, 3)
                    agent_location = agent_location[:, :, :, 0]
                    # Select object id (10 is presumably the agent's id in
                    # the grid encoding -- TODO confirm).
                    agent_location = (agent_location == 10).nonzero()
                    agent_location = agent_location[:, 2]
                    agent_location = agent_location.view(
                        agent_output["action"].shape)
                    reached_condition = goal == agent_location

                need_new_goal = False
                if reached_condition:
                    # Intrinsic goal reached: restart the episode or just
                    # reset the wrapper's step counter.
                    if flags.restart_episode:
                        env_output = env.initial()
                    else:
                        env.episode_step = 0
                    need_new_goal = True
                # Checked after a possible restart, so the freshly reset
                # `env_output` is the one inspected.
                if env_output['done'][0] == 1:
                    need_new_goal = True

                if need_new_goal:
                    # One generator call covers both triggers; the goal is
                    # generated from (and anchored to) the current frame.
                    initial_frame = env_output['frame']
                    with torch.no_grad():
                        generator_output = generator_model(env_output)
                    goal = generator_output["goal"]

                with torch.no_grad():
                    agent_output, agent_state = model(
                        env_output, agent_state, goal)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]
                for key in generator_output:
                    buffers[key][index][t + 1, ...] = generator_output[key]
                buffers["initial_frame"][index][t + 1, ...] = initial_frame

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise