in src/markov/deepracer_memory.py [0:0]
def fetch(self, num_consecutive_playing_steps=None):
    ''' Retrieves the data from the rollout worker

    This is a generator. Every value it yields is a dict keyed like
    self.data_queues that holds, for each key, the next transition of the
    episode just received on that key's queue (None once that episode has run
    out of transitions while another key's episode still has some).

    Each item read from a queue is indexed as a pair: item[0] is the episode
    index and item[1] is the payload (an Episode for regular data).

    num_consecutive_playing_steps - Struct containing the number of episodes to
                                    collect before performing a training iteration
                                    (only its num_steps attribute is read)

    Raises AttributeError when the generator is first advanced if
    num_consecutive_playing_steps has no num_steps attribute (for example the
    default None). The attribute is read before any state is touched, so a bad
    argument does not leave this object requesting data.
    '''
    # Read the episode budget once, up front: it does not change during the
    # fetch, and failing here keeps the queues/flags/events untouched.
    num_steps = num_consecutive_playing_steps.num_steps
    last_episode = num_steps - 1
    # Index the rollout worker answers with once it has no more data to give.
    done_marker = num_steps + 1

    episode_counter = 0
    step_counter = 0
    self.episode_req = 0
    # Clear any left over Episodes data in queue from previous fetch
    # (at most one item is removed from each non-empty queue).
    for agent_queue in self.data_queues.values():
        if not agent_queue.empty():
            agent_queue.get()
    self.request_data = True
    for event in self.request_events.values():
        event.set()
    self.rollout_steps = dict.fromkeys(self.rollout_steps, 0)
    self.total_episodes_in_rollout = 0

    while episode_counter <= num_steps:
        try:
            # Blocks until every queue has delivered one item.
            objs = {k: v.get() for k, v in self.data_queues.items()}

            if all(obj[0] == episode_counter and isinstance(obj[1], Episode) for obj in objs.values()):
                step_counter += sum(obj[1].length() for obj in objs.values())
                if step_counter <= self.max_step:
                    self.rollout_steps = {k: self.rollout_steps[k] + objs[k][1].length()
                                          for k in self.rollout_steps}
                    self.total_episodes_in_rollout += 1
                    # Walk all episodes in lock-step, one transition per key per yield,
                    # until every episode is exhausted.
                    transition_iters = {k: iter(v[1].transitions) for k, v in objs.items()}
                    transition = {k: next(v, None) for k, v in transition_iters.items()}
                    while any(transition.values()):
                        yield transition
                        transition = {k: next(v, None) for k, v in transition_iters.items()}
                elif episode_counter != last_episode:
                    # If step_counter goes over self.max_step, then directly request
                    # last episode (index of last episode: num_consecutive_playing_steps.num_steps - 1).
                    # If we just increment the episode one by one till the last one, then it will basically fill up
                    # Redis memory that resides in training worker.
                    # When rollout worker actually returns last episode, then we safely increment episode_counter
                    # to num_consecutive_playing_steps.num_steps, so both rollout worker and training worker can
                    # finish the epoch gracefully.
                    episode_counter = last_episode
                    self.episode_req = episode_counter
                    continue
                episode_counter += 1
                self.episode_req = episode_counter
            # When we request num_consecutive_playing_steps.num_steps we will get back
            # 1 more than the requested index this lets us know the rollout worker
            # has given us all available data
            elif all(obj[0] == done_marker for obj in objs.values()):
                episode_counter = done_marker
                self.episode_req = 0
                self.request_data = False
                continue
            # Signal the request events again so the next episode request goes out.
            for event in self.request_events.values():
                event.set()
        except Exception as ex:
            # Best effort: report the problem and keep polling the queues.
            LOG.info("Trainer fetch error: %s", ex)