in src/markov/rollout_worker.py [0:0]
def rollout_worker(graph_manager, num_workers, rollout_idx, task_parameters, simtrace_video_s3_writers,
                   pause_physics, unpause_physics):
    """Wait for the first checkpoint, then repeatedly perform rollouts using the model.

    Waits (via the data store) for checkpoints and for the trainer to be ready,
    waits for the required ROS services, builds the graph, and then loops:
    act for this worker's share of episodes, then, in SYNC mode, evaluate while
    waiting for the trainer's next checkpoint and restore it before the next round.

    Args:
        graph_manager: Coach-style graph manager; must have a ``data_store``
            attached.
        num_workers: Total number of rollout workers that share the
            per-rollout episode budget.
        rollout_idx: Index of this worker. Only worker 0 runs the evaluation
            episodes and may record mp4s.
        task_parameters: Provides ``checkpoint_restore_path`` and is forwarded
            to ``graph_manager.create_graph``.
        simtrace_video_s3_writers: Writers whose ``persist`` is called to upload
            simtrace/mp4 data to S3 after each SYNC evaluation phase; also
            passed to ``exit_if_trainer_done``.
        pause_physics: Callable taking an ``EmptyRequest`` that pauses the
            simulation physics (forwarded to ``create_graph`` as
            ``stop_physics``).
        unpause_physics: Callable taking an ``EmptyRequest`` that resumes the
            simulation physics (forwarded to ``create_graph`` as
            ``start_physics``).

    Raises:
        AttributeError: If ``graph_manager.data_store`` is not set.

    Returns nothing. ``exit_if_trainer_done`` (defined elsewhere) is checked
    every round and while waiting for checkpoints; presumably it is what ends
    the worker once training finishes — confirm against its definition.
    """
    if not graph_manager.data_store:
        raise AttributeError("None type for data_store object")
    is_sageonly = utils.check_is_sageonly()
    data_store = graph_manager.data_store
    # TODO: change "agent" to the specific agent name for the multi-agent case.
    checkpoint_dir = os.path.join(task_parameters.checkpoint_restore_path, "agent")
    graph_manager.data_store.wait_for_checkpoints()
    graph_manager.data_store.wait_for_trainer_ready()
    # Wait for the required cancel service to become available.
    # Do this only for a RoboMaker job (it is not waited for in SageMaker-only mode).
    if not is_sageonly:
        rospy.wait_for_service('/robomaker/job/cancel')
    # Wait for the services behind the pause/unpause physics clients and the
    # mp4-recording (un)subscribe services, then create proxies for the latter.
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    rospy.wait_for_service('/racecar/save_mp4/subscribe_to_save_mp4')
    rospy.wait_for_service('/racecar/save_mp4/unsubscribe_from_save_mp4')
    subscribe_to_save_mp4 = ServiceProxyWrapper('/racecar/save_mp4/subscribe_to_save_mp4', Empty)
    unsubscribe_from_save_mp4 = ServiceProxyWrapper('/racecar/save_mp4/unsubscribe_from_save_mp4', Empty)
    graph_manager.create_graph(task_parameters=task_parameters, stop_physics=pause_physics,
                               start_physics=unpause_physics, empty_service_call=EmptyRequest)
    # Number of the checkpoint the graph currently starts from; each SYNC
    # round waits for a checkpoint number >= last_checkpoint + 1.
    chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
    last_checkpoint = chkpt_state_reader.get_latest().num
    # This worker plays a fraction of the total playing steps per rollout.
    # The budget is split evenly. The first (total % num_workers) workers get
    # one extra, so the shares add up to the total.
    # Note: the count is wrapped in EnvironmentEpisodes, so it is a number of
    # episodes despite the "steps" naming.
    episode_steps_per_rollout = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    act_steps = int(episode_steps_per_rollout / num_workers)
    if rollout_idx < episode_steps_per_rollout % num_workers:
        act_steps += 1
    act_steps = EnvironmentEpisodes(act_steps)
    configure_environment_randomizer()
    # Presumably the number of rollout rounds needed to cover improve_steps
    # (relies on Coach step-object division) — confirm.
    for _ in range((graph_manager.improve_steps / act_steps.num_steps).num_steps):
        # Collect profiler information only when IS_PROFILER_ON is true.
        with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET, s3_prefix=PROFILER_S3_PREFIX,
                            output_local_path=ROLLOUT_WORKER_PROFILER_PATH, enable_profiling=IS_PROFILER_ON):
            # --- Rollout phase: run physics only while acting. ---
            graph_manager.phase = RunPhase.TRAIN
            exit_if_trainer_done(checkpoint_dir, simtrace_video_s3_writers, rollout_idx)
            unpause_physics(EmptyRequest())
            graph_manager.reset_internal_state(True)
            graph_manager.act(act_steps, wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
            graph_manager.reset_internal_state(True)
            # NOTE(review): fixed 1 s delay before pausing physics; its purpose
            # is not evident from this function — confirm before changing.
            time.sleep(1)
            pause_physics(EmptyRequest())
            graph_manager.phase = RunPhase.UNDEFINED
            # Sentinel meaning "no newer checkpoint observed this round";
            # only the SYNC branch below ever overwrites it.
            new_checkpoint = -1
            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                    == DistributedCoachSynchronizationType.SYNC:
                # --- SYNC: evaluate, then block until the trainer publishes the next checkpoint. ---
                unpause_physics(EmptyRequest())
                # Truthy only when the MP4_S3_BUCKET param is set (non-empty) and this is worker 0.
                is_save_mp4_enabled = rospy.get_param('MP4_S3_BUCKET', None) and rollout_idx == 0
                if is_save_mp4_enabled:
                    subscribe_to_save_mp4(EmptyRequest())
                if rollout_idx == 0:
                    for _ in range(MIN_EVAL_TRIALS):
                        graph_manager.evaluate(EnvironmentSteps(1))
                # For a SageMaker-only job, run only the limited number of evaluations above, for better performance.
                # Pausing the physics makes its performance the same as RoboMaker + SageMaker.
                if is_sageonly:
                    if is_save_mp4_enabled:
                        unsubscribe_from_save_mp4(EmptyRequest())
                    # Upload simtrace and mp4 into the S3 bucket.
                    for s3_writer in simtrace_video_s3_writers:
                        s3_writer.persist(utils.get_s3_kms_extra_args())
                    graph_manager.phase = RunPhase.WAITING
                    pause_physics(EmptyRequest())
                # Poll until the trainer has published a checkpoint newer than the one in use.
                # NOTE(review): there is no sleep in this loop. In SageMaker-only
                # mode it only polls get_coach_checkpoint_number, so it may spin
                # unless that call blocks internally — verify.
                while new_checkpoint < last_checkpoint + 1:
                    exit_if_trainer_done(checkpoint_dir, simtrace_video_s3_writers, rollout_idx)
                    # Continuously run the evaluation only for a SageMaker + RoboMaker job.
                    if not is_sageonly and rollout_idx == 0:
                        graph_manager.evaluate(EnvironmentSteps(1))
                    new_checkpoint = data_store.get_coach_checkpoint_number('agent')
                # Save the mp4 for RoboMaker + SageMaker jobs; physics was kept
                # running during the wait above, so pause it here.
                if not is_sageonly:
                    if is_save_mp4_enabled:
                        unsubscribe_from_save_mp4(EmptyRequest())
                    # Upload simtrace and mp4 into the S3 bucket.
                    for s3_writer in simtrace_video_s3_writers:
                        s3_writer.persist(utils.get_s3_kms_extra_args())
                    pause_physics(EmptyRequest())
                data_store.load_from_store(expected_checkpoint_number=last_checkpoint + 1)
                graph_manager.restore_checkpoint()
            # NOTE(review): this ASYNC branch looks dead. new_checkpoint is reset
            # to -1 above and is only updated in the SYNC branch, so the
            # comparison below is never true. The assignment to last_checkpoint
            # that follows then sets it to -1. An ASYNC worker therefore never
            # restores a newer policy. Fixing this needs a (non-blocking) read
            # of the latest checkpoint number here — confirm the intended
            # data-store API before changing.
            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                    == DistributedCoachSynchronizationType.ASYNC:
                if new_checkpoint > last_checkpoint:
                    graph_manager.restore_checkpoint()
            last_checkpoint = new_checkpoint