def act()

in source/MXNetEnv/training/training_src/networks/agent.py [0:0]


    def act(self, state, snake_id, turn_count, snake_health, episode, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            snake_id (int): ID of the snake
            turn_count (int): turn count of the game
            snake_health (int): health of the snake
            episode (int): The current episode, used for checking previous states
            eps (float): epsilon, for epsilon-greedy action selection
        """
        # Epsilon-greedy action selection: with probability 1 - eps, act greedily
        # on a sequence of recent states; otherwise pick a (constrained) random action
        if random.random() > eps:
            snake_health = snake_health - 1  # Account for the health cost of taking the current move

            # Sentinel ("end of sequence") padding values, used when a remembered
            # transition does not belong to the current episode or turn sequence
            empty_state = np.zeros(state.shape)
            turn_count_eos = -1
            snake_health_eos = 101

            with autograd.predict_mode():
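                # Pull the previous sequence_length - 1 transitions from replay
                # memory; the caller's current observation fills the final slot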
                last_n_memory = self.memory.get_last_n(n=self.sequence_length - 1)
                state_sequence = []
                turn_count_sequence, snake_health_sequence = [], []
                for i in range(self.sequence_length):
                    if i == self.sequence_length - 1:
                        turn_count_i = turn_count
                        episode_i = episode
                        delta = 0
                        state_i = state
                        snake_health_i = snake_health
                    else:
                        turn_count_i = last_n_memory[i].turn_count
                        episode_i = last_n_memory[i].episode
                        delta = self.sequence_length - 1 - i
                        state_i = last_n_memory[i].state
                        snake_health_i = last_n_memory[i].snake_health

                    # Keep the remembered entry only if it belongs to this episode
                    # and sits exactly `delta` turns before the current turn;
                    # otherwise pad the slot with the sentinel values above
                    episode_correct = episode_i == episode
                    turn_correct = turn_count_i + delta == turn_count
                    if episode_correct and turn_correct:
                        state_sequence.append(state_i)
                        turn_count_sequence.append(turn_count_i)
                        snake_health_sequence.append(snake_health_i)
                    else:
                        state_sequence.append(empty_state)
                        turn_count_sequence.append(turn_count_eos)
                        snake_health_sequence.append(snake_health_eos)

                # Stack frames to (seq_len, H, W, C), transpose each frame to
                # channels-first for the conv layers, and add a batch dimension
                state_sequence = mx.nd.array(np.stack(state_sequence),
                                             ctx=ctx).transpose((0, 3, 1, 2)).expand_dims(0)
                turn_count_sequence = mx.nd.array(np.stack(turn_count_sequence),
                                                  ctx=ctx).expand_dims(0)
                snake_health_sequence = mx.nd.array(np.stack(snake_health_sequence),
                                                    ctx=ctx).expand_dims(0)

                snake_id_sequence = mx.nd.array(np.array([snake_id]*self.sequence_length),
                                                ctx=ctx).expand_dims(0)
                
                # Network variants that also condition on snake ID, turn count,
                # and health take them as extra forward arguments
                if self.qnetwork_local.take_additional_forward_arguments:
                    action_values = self.qnetwork_local(state_sequence,
                                                        snake_id_sequence,
                                                        turn_count_sequence,
                                                        snake_health_sequence)
                else:
                    action_values = self.qnetwork_local(state_sequence)

            return np.argmax(action_values.asnumpy())  # index of the highest Q-value
        else:
            last_memory = self.memory.get_last_n(n=1)[0]
            if last_memory is not None:
                # Avoid randomly choosing a forbidden move (here, the previous action)
                last_action = last_memory.action
                last_episode = last_memory.episode
                last_turn_count = last_memory.turn_count
                if last_episode == episode and (last_turn_count == turn_count - 1):
                    # Only exclude the last action when last_memory comes from the
                    # same episode and the immediately preceding turn
                    action_space = [*range(self.action_size)]
                    action_space.remove(last_action)
                    return random.choice(action_space)
            return random.choice(np.arange(self.action_size))
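
For context, here is a minimal sketch of how act() might be driven from a training loop. Only act()'s signature comes from the source above; the Agent constructor, the env object and its API, the memory.add signature, and the eps schedule are assumptions for illustration.

    # Hypothetical driver loop (all names except act() are assumptions)
    agent = Agent(state_size=(11, 11, 6), action_size=4, sequence_length=4)  # assumed constructor

    num_episodes = 500
    eps = 1.0                                    # start fully exploratory
    for episode in range(1, num_episodes + 1):
        state = env.reset()                      # assumed environment API
        turn_count, snake_health, done = 0, 100, False
        while not done:
            action = agent.act(state, snake_id=0, turn_count=turn_count,
                               snake_health=snake_health, episode=episode, eps=eps)
            next_state, reward, done = env.step(action)  # assumed environment API
            # Assumed signature; whatever is stored must expose the fields act()
            # reads back via get_last_n(): state, action, episode, turn_count, snake_health
            agent.memory.add(state, action, episode, turn_count, snake_health)
            state = next_state
            turn_count += 1
            snake_health -= 1                    # naive per-turn decay, illustration only
        eps = max(0.05, eps * 0.995)             # decay exploration across episodes

Decaying eps across episodes is the usual epsilon-greedy schedule; the source only shows how eps is consumed inside act(), not how it is scheduled.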