in source/MXNetEnv/training/training_src/networks/agent.py
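# Note: this excerpt assumes the surrounding module provides `random`,
# `numpy as np`, `mxnet as mx`, `autograd` (from mxnet), and a device
# context `ctx`; those imports/definitions live elsewhere in agent.py and
# are not shown here.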
def act(self, state, snake_id, turn_count, snake_health, episode, eps=0.):
    """Returns an action for the given state as per the current policy.

    Params
    ======
        state (array_like): current state
        snake_id (int): ID of the snake
        turn_count (int): turn count of the game
        snake_health (int): health of the snake
        episode (int): the current episode, used for checking previous states
        eps (float): epsilon, for epsilon-greedy action selection
    """
    # Epsilon-greedy action selection
    if random.random() > eps:
        snake_health = snake_health - 1  # Account for taking the current move
        empty_state = np.zeros(state.shape)
        # Out-of-range "end of sequence" sentinels used to pad slots that
        # have no matching memory (-1 is an invalid turn count, 101 exceeds
        # the maximum health)
        turn_count_eos = -1
        snake_health_eos = 101
        with autograd.predict_mode():
            # Fetch the last (sequence_length - 1) transitions; the current
            # observation fills the final slot of the sequence
            last_n_memory = self.memory.get_last_n(n=self.sequence_length - 1)
            state_sequence, snake_id_sequence = [], []
            turn_count_sequence, snake_health_sequence = [], []
            for i in range(self.sequence_length):
                if i == self.sequence_length - 1:
                    # Final slot: the current observation
                    turn_count_i = turn_count
                    episode_i = episode
                    delta = 0
                    state_i = state
                    snake_health_i = snake_health
                else:
                    turn_count_i = last_n_memory[i].turn_count
                    episode_i = last_n_memory[i].episode
                    # How many turns this memory should lag the current turn by
                    delta = self.sequence_length - 1 - i
                    state_i = last_n_memory[i].state
                    snake_health_i = last_n_memory[i].snake_health
                # Keep only memories from this episode that line up turn-by-turn
                episode_correct = episode_i == episode
                turn_correct = turn_count_i + delta == turn_count
                if episode_correct and turn_correct:
                    state_sequence.append(state_i)
                    turn_count_sequence.append(turn_count_i)
                    snake_health_sequence.append(snake_health_i)
                else:
                    # Pad mismatched slots with the empty state and sentinels
                    state_sequence.append(empty_state)
                    turn_count_sequence.append(turn_count_eos)
                    snake_health_sequence.append(snake_health_eos)
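            # Stack the per-step lists into NDArrays with a leading batch
            # axis of 1; states move from channels-last (seq, H, W, C) to
            # channels-first (seq, C, H, W), MXNet's default layout for
            # convolutions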
            state_sequence = mx.nd.array(np.stack(state_sequence),
                                         ctx=ctx).transpose((0, 3, 1, 2)).expand_dims(0)
            turn_count_sequence = mx.nd.array(np.stack(turn_count_sequence),
                                              ctx=ctx).expand_dims(0)
            snake_health_sequence = mx.nd.array(np.stack(snake_health_sequence),
                                                ctx=ctx).expand_dims(0)
            snake_id_sequence = mx.nd.array(np.array([snake_id] * self.sequence_length),
                                            ctx=ctx).expand_dims(0)
            # Some networks also condition on snake ID, turn count and health
            if self.qnetwork_local.take_additional_forward_arguments:
                action_values = self.qnetwork_local(state_sequence,
                                                    snake_id_sequence,
                                                    turn_count_sequence,
                                                    snake_health_sequence)
            else:
                action_values = self.qnetwork_local(state_sequence)
            # Greedy action: index of the highest predicted Q-value
            return np.argmax(action_values.asnumpy())
    else:
        # Explore: pick a random action
        last_memory = self.memory.get_last_n(n=1)[0]
        if last_memory is not None:
            # Exclude the forbidden move (the previous action) from the
            # random choice
            last_action = last_memory.action
            last_episode = last_memory.episode
            last_turn_count = last_memory.turn_count
            if last_episode == episode and (last_turn_count == turn_count - 1):
                # Only when last_memory is the previous turn of the same episode
                action_space = [*range(self.action_size)]
                action_space.remove(last_action)
                return random.choice(action_space)
        return random.choice(np.arange(self.action_size))
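
# A minimal usage sketch (the names `agent` and `board_state` are
# illustrative assumptions, not part of this module):
#
#     action = agent.act(board_state, snake_id=0, turn_count=12,
#                        snake_health=87, episode=3, eps=0.05)
#
# The return value is an index into the agent's action space.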