in src/markov/training_worker.py [0:0]
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout, training_algorithm):
    """Run the training loop for the distributed training worker.

    Repeatedly pulls rollout experience from the rollout workers via the
    memory backend, trains for one iteration, checkpoints the model, and
    exits cleanly on SIGTERM or fatal errors.

    Args:
        graph_manager: Coach graph manager owning the agents, memory backend,
            data store, and training graph.
        task_parameters: Coach task parameters used to build the graph.
        user_batch_size (int): batch size requested in the hyperparameters.
        user_episode_per_rollout (int): episodes each rollout collects per
            training iteration.
        training_algorithm (str): a TrainingAlgorithm enum value; used for
            SAC-specific handling.
    """
    try:
        # Initialize the graph and persist an initial checkpoint so rollout
        # workers have a model to start from.
        graph_manager.create_graph(task_parameters)
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # Install signal handlers so SIGTERM triggers a final checkpoint.
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only if IS_PROFILER_ON is true
            with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET, s3_prefix=PROFILER_S3_PREFIX,
                                output_local_path=TRAINING_WORKER_PROFILER_PATH,
                                enable_profiling=IS_PROFILER_ON):
                # Pull the latest rollout data from the rollout workers.
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout
                        # TODO: Refactor the flow to remove conditional checks for specific algorithms
                        # ------------------------sac only---------------------------------------------
                        if training_algorithm == TrainingAlgorithm.SAC.value:
                            rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                            # NOTE: you can train even more iterations than rollout_steps by
                            # increasing the number below for SAC
                            agent.ap.algorithm.num_consecutive_training_steps = \
                                list(rollout_steps.values())[0]  # rollout_steps[agent]
                        # -----------------------------------------------------------------------------

                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches.
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                    # BUGFIX: was `any(rollout_steps.values()) <= 0`, which compares the
                    # *boolean* result of any() with 0 and only fires when every worker
                    # reported zero steps. A single zero entry would slip through and
                    # crash math.log(0, 2) in the batch-size computation below.
                    if any(step <= 0 for step in rollout_steps.values()):
                        log_and_exit("No rollout data retrieved from the rollout worker",
                                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                                     SIMAPP_EVENT_ERROR_CODE_500)

                    # Set the batch size to the closest power of 2 such that we have at
                    # least two batches; this prevents coach from crashing, as a batch
                    # size less than 2 causes the batch list to become a scalar, which
                    # causes an exception.
                    # TODO: Refactor the flow to remove conditional checks for specific algorithms
                    # DH: for SAC, check if experience replay memory has enough transitions
                    if training_algorithm == TrainingAlgorithm.SAC.value:
                        replay_mem_size = min([agent.memory.num_transitions()
                                               for level in graph_manager.level_managers
                                               for agent in level.agents.values()])
                        episode_batch_size = user_batch_size if replay_mem_size > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    else:
                        episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))

                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            for net_key in agent.ap.network_wrappers:
                                agent.ap.network_wrappers[net_key].batch_size = episode_batch_size

                    steps += 1

                    graph_manager.phase = core_types.RunPhase.TRAIN
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED

                    # Check all agents for NaN loss; abort rather than persist a
                    # corrupted model.
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                    if rollout_has_nan:
                        log_and_exit("NaN detected in loss function, aborting training.",
                                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                                     SIMAPP_EVENT_ERROR_CODE_500)

                    # SYNC mode checkpoints every iteration so rollout workers stay
                    # in lockstep; otherwise checkpoint only occasionally.
                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == \
                            DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()

                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()

                # Restore per-rollout step counts that were overwritten above.
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

                if door_man.terminate_now:
                    log_and_exit("Received SIGTERM. Checkpointing before exiting.",
                                 SIMAPP_TRAINING_WORKER_EXCEPTION,
                                 SIMAPP_EVENT_ERROR_CODE_500)
                    graph_manager.save_checkpoint()
                    break

    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit("User modified model: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        else:
            log_and_exit("An error occured while training: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occured while training: {}".format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        # Always signal completion so downstream consumers don't hang waiting
        # on the data store, even on error paths.
        graph_manager.data_store.upload_finished_file()