def get_graph_manager()

in src/markov/sagemaker_graph_manager.py


def get_graph_manager(hp_dict, agent_list, run_phase_subject, enable_domain_randomization=False,
                      done_condition=any, run_type=str(RunType.ROLLOUT_WORKER),
                      pause_physics=None, unpause_physics=None):
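    """Build a Coach MultiAgentGraphManager for DeepRacer training or rollout.

    Args:
        hp_dict (dict): Raw hyperparameter overrides, merged with per-algorithm defaults.
        agent_list (list): Agents to run; only agents with network settings are trained.
        run_phase_subject: Subject used to broadcast the current run phase to subscribers.
        enable_domain_randomization (bool): Enable domain randomization in the environment.
        done_condition (callable): any/all; how per-agent done flags combine into an episode done.
        run_type (str): RunType value identifying the worker (rollout, training, etc.).
        pause_physics / unpause_physics: Callbacks to pause/resume simulation physics.

    Returns:
        tuple: (MultiAgentGraphManager, JSON string of the dumped hyperparameters)
    """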
    ####################
    # Hyperparameters #
    ####################
    # Note: The following three lines are hard-coded to pick the first agent's training
    # algorithm and dump that algorithm's hyperparameters to JSON for training jobs
    # (so that the console displays the training hyperparameters correctly), since we
    # currently support training only one model at a time.
    # TODO: clean these lines up when we support multi-agent training.
    training_algorithm = agent_list[0].ctrl.model_metadata.training_algorithm if agent_list else None
    params = get_updated_hyper_parameters(hp_dict, training_algorithm)
    params_json = json.dumps(params, indent=2, sort_keys=True)
    print("Using the following hyper-parameters", params_json, sep='\n')

    ####################
    # Graph Scheduling #
    ####################
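    # Improve for the configured max-episode budget; run a 5-episode evaluation
    # every 40 training episodes; skip the heatup phase entirely.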
    schedule_params = ScheduleParameters()
    schedule_params.improve_steps = TrainingSteps(params[HyperParameterKeys.TERMINATION_CONDITION_MAX_EPISODES.value])
    schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(40)
    schedule_params.evaluation_steps = EnvironmentEpisodes(5)
    schedule_params.heatup_steps = EnvironmentSteps(0)

    #########
    # Agent #
    #########
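    # Only agents with network settings are trained; the rest (e.g. bot vehicles
    # or obstacles) join the environment as non-trainable agents.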
    trainable_agents_list = list()
    non_trainable_agents_list = list()

    for agent in agent_list:
        if agent.network_settings:
            training_algorithm = agent.ctrl.model_metadata.training_algorithm
            params = get_updated_hyper_parameters(hp_dict, training_algorithm)
            if TrainingAlgorithm.SAC.value == training_algorithm:
                agent_params = get_sac_params(DeepRacerSACAgentParams(), agent, params, run_type)
            else:
                agent_params = get_clipped_ppo_params(DeepRacerClippedPPOAgentParams(), agent, params)
            agent_params.env_agent = agent
            input_filter = InputFilter(is_a_reference_filter=True)
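            # Attach preprocessing filters per observation, keyed by the sensor
            # type of each input embedder.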
            for observation in agent.network_settings['input_embedders'].keys():
                if observation in (Input.LEFT_CAMERA.value, Input.CAMERA.value,
                                   Input.OBSERVATION.value):
                    input_filter.add_observation_filter(observation,
                                                        'to_grayscale', ObservationRGBToYFilter())
                    input_filter.add_observation_filter(observation,
                                                        'to_uint8',
                                                        ObservationToUInt8Filter(0, 255))
                    input_filter.add_observation_filter(observation,
                                                        'stacking', ObservationStackingFilter(1))

                if observation == Input.STEREO.value:
                    input_filter.add_observation_filter(observation,
                                                        'to_uint8',
                                                        ObservationToUInt8Filter(0, 255))

                if observation == Input.LIDAR.value:
                    input_filter.add_observation_filter(observation,
                                                        'clipping',
                                                        ObservationClippingFilter(0.15, 1.0))
                if observation == Input.SECTOR_LIDAR.value:
                    sector_binary_filter = ObservationSectorDiscretizeFilter(num_sectors=NUMBER_OF_LIDAR_SECTORS,
                                                                             num_values_per_sector=1,
                                                                             clipping_dist=SECTOR_LIDAR_CLIPPING_DIST)
                    input_filter.add_observation_filter(observation,
                                                        'binary',
                                                        sector_binary_filter)
                if observation == Input.DISCRETIZED_SECTOR_LIDAR.value:
                    num_sectors = agent.ctrl.model_metadata.lidar_num_sectors
                    num_values_per_sector = agent.ctrl.model_metadata.lidar_num_values_per_sector
                    clipping_dist = agent.ctrl.model_metadata.lidar_clipping_dist

                    sector_discretize_filter = ObservationSectorDiscretizeFilter(num_sectors=num_sectors,
                                                                                 num_values_per_sector=num_values_per_sector,
                                                                                 clipping_dist=clipping_dist)
                    input_filter.add_observation_filter(observation,
                                                        'discrete',
                                                        sector_discretize_filter)
            agent_params.input_filter = input_filter()
            trainable_agents_list.append(agent_params)
        else:
            non_trainable_agents_list.append(agent)

    ###############
    # Environment #
    ###############
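    # The environment wraps both trainable and non-trainable agents and receives
    # the physics pause/unpause hooks.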
    env_params = DeepRacerRacetrackEnvParameters()
    env_params.agents_params = trainable_agents_list
    env_params.non_trainable_agents = non_trainable_agents_list
    env_params.level = 'DeepRacerRacetrackEnv-v0'
    env_params.run_phase_subject = run_phase_subject
    env_params.enable_domain_randomization = enable_domain_randomization
    env_params.done_condition = done_condition
    env_params.pause_physics = pause_physics
    env_params.unpause_physics = unpause_physics
    vis_params = VisualizationParameters()
    vis_params.dump_mp4 = False

    ########
    # Test #
    ########
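    # rl_coach preset validation settings; used when the preset is run in test mode.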
    preset_validation_params = PresetValidationParameters()
    preset_validation_params.test = True
    preset_validation_params.min_reward_threshold = 400
    preset_validation_params.max_episodes_to_achieve_reward = 10000

    graph_manager = MultiAgentGraphManager(agents_params=trainable_agents_list,
                                           env_params=env_params,
                                           schedule_params=schedule_params, vis_params=vis_params,
                                           preset_validation_params=preset_validation_params,
                                           done_condition=done_condition)
    return graph_manager, params_json
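

A minimal usage sketch follows. The hp_dict keys, the agent and run-phase-subject construction, and the final create_graph() call are illustrative assumptions about the surrounding src/markov code and the rl_coach API, not taken from this file:

    # Hypothetical caller -- hp_dict contents and helper objects are assumptions.
    hp_dict = {"batch_size": 64, "lr": 0.0003, "term_cond_max_episodes": 1000}
    run_phase_subject = RunPhaseSubject()  # assumed phase-broadcast subject from src/markov
    agent_list = [rollout_agent]  # built elsewhere, with network_settings populated
    graph_manager, params_json = get_graph_manager(hp_dict=hp_dict,
                                                   agent_list=agent_list,
                                                   run_phase_subject=run_phase_subject,
                                                   run_type=str(RunType.ROLLOUT_WORKER))
    print(params_json)  # merged hyperparameters, as shown in the console
    graph_manager.create_graph()  # standard rl_coach step to build the graph before running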