def training_worker()

in src/markov/training_worker.py


def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout, training_algorithm):
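    """Run the training loop for the distributed training worker.

    Args:
        graph_manager: coach graph manager that owns the level managers and agents
        task_parameters: coach task parameters used to create the training graph
        user_batch_size: training batch size requested by the user
        user_episode_per_rollout: number of episodes expected per rollout
        training_algorithm: name of the training algorithm, e.g. TrainingAlgorithm.SAC.value
    """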
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

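        # Attach to the memory backend that rollout workers publish episodes to,
        # and signal that the trainer is ready to receive them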
        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # DoorMan traps SIGTERM so the loop can save a checkpoint before exiting
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only if IS_PROFILER_ON is true
            with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET, s3_prefix=PROFILER_S3_PREFIX,
                                output_local_path=TRAINING_WORKER_PROFILER_PATH, enable_profiling=IS_PROFILER_ON):
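                # Switch to the TRAIN phase and pull the newly collected episodes from the rollout workers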
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED

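                # Align the per-agent playing / target-copy step counts with the number of
                # episodes actually received in this rollout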
                episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

                # TODO: Refactor the flow to remove conditional checks for specific algorithms
                # ------------------------sac only---------------------------------------------
                if training_algorithm == TrainingAlgorithm.SAC.value:
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()

                    # NOTE: you can train even more iterations than rollout_steps by increasing the number below for SAC
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            agent.ap.algorithm.num_consecutive_training_steps = \
                                list(rollout_steps.values())[0]  # rollout_steps[agent]
                # -------------------------------------------------------------------------------
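                # Train only when the graph manager decides enough new experience has been collected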
                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                    if any(rollout_step <= 0 for rollout_step in rollout_steps.values()):
                        log_and_exit("No rollout data retrieved from the rollout worker",
                                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                                     SIMAPP_EVENT_ERROR_CODE_500)

                    # TODO: Refactor the flow to remove conditional checks for specific algorithms
                    # Use the user-requested batch size when there is enough data; otherwise fall back to the
                    # closest power of 2 so that we still get at least two batches. A batch size of less than 2
                    # causes the batch list to become a scalar, which makes coach crash with an exception.
                    if training_algorithm == TrainingAlgorithm.SAC.value:
                        # DH: for SAC, check if the experience replay memory has enough transitions
                        replay_mem_size = min([agent.memory.num_transitions()
                                               for level in graph_manager.level_managers
                                               for agent in level.agents.values()])
                        episode_batch_size = user_batch_size if replay_mem_size > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    else:
                        episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            for net_key in agent.ap.network_wrappers:
                                agent.ap.network_wrappers[net_key].batch_size = episode_batch_size

                    steps += 1

                    graph_manager.phase = core_types.RunPhase.TRAIN
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED

                    # Check for NaNs in the loss of any agent
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                    if rollout_has_nan:
                        log_and_exit("NaN detected in loss function, aborting training.",
                                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                                     SIMAPP_EVENT_ERROR_CODE_500)

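                    # With SYNC distributed coach synchronization, rollout workers wait for a new
                    # checkpoint after every training iteration, so always save one here; otherwise
                    # checkpoints are only saved occasionally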
                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()

                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()

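                # Restore the user-configured episodes-per-rollout before fetching the next rollout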
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

                if door_man.terminate_now:
                    # Save a final checkpoint before log_and_exit terminates the process
                    graph_manager.save_checkpoint()
                    log_and_exit("Received SIGTERM. Checkpointing before exiting.",
                                 SIMAPP_TRAINING_WORKER_EXCEPTION,
                                 SIMAPP_EVENT_ERROR_CODE_500)
                    break

    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit("User modified model: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        else:
            log_and_exit("An error occured while training: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occured while training: {}".format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
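        # Always upload the finished marker so downstream workers can detect that training has ended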
        graph_manager.data_store.upload_finished_file()