in src/markov/sagemaker_graph_manager.py
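# The module's imports are omitted from this excerpt. A sketch of what the
# function relies on, assuming standard rl_coach module paths (the markov.*
# names it also uses, such as RunType, TrainingAlgorithm, HyperParameterKeys,
# Input, the DeepRacer agent/env parameter classes, MultiAgentGraphManager,
# ObservationSectorDiscretizeFilter, get_updated_hyper_parameters,
# get_sac_params, get_clipped_ppo_params, NUMBER_OF_LIDAR_SECTORS and
# SECTOR_LIDAR_CLIPPING_DIST, come from elsewhere in this package):
import json

from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.observation.observation_clipping_filter import ObservationClippingFilter
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.graph_managers.graph_manager import ScheduleParameters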
def get_graph_manager(hp_dict, agent_list, run_phase_subject, enable_domain_randomization=False,
                      done_condition=any, run_type=str(RunType.ROLLOUT_WORKER),
                      pause_physics=None, unpause_physics=None):
    ####################
    # Hyperparameters  #
    ####################
    # Note: The following three lines are hard-coded to pick the first agent's
    # training algorithm and dump the hyperparameters for that particular
    # training algorithm into JSON for training jobs (so that the console
    # displays the training hyperparameters correctly), since right now we
    # only support training one model at a time.
    # TODO: clean these lines up when we support multi-agent training.
    training_algorithm = agent_list[0].ctrl.model_metadata.training_algorithm if agent_list else None
    params = get_updated_hyper_parameters(hp_dict, training_algorithm)
    params_json = json.dumps(params, indent=2, sort_keys=True)
    print("Using the following hyper-parameters", params_json, sep='\n')
    ####################
    # Graph Scheduling #
    ####################
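    # The improvement budget is driven by the max-episodes termination
    # hyperparameter; evaluation runs 5 episodes every 40 training episodes,
    # with no heatup phase.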
    schedule_params = ScheduleParameters()
    schedule_params.improve_steps = TrainingSteps(
        params[HyperParameterKeys.TERMINATION_CONDITION_MAX_EPISODES.value])
    schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(40)
    schedule_params.evaluation_steps = EnvironmentEpisodes(5)
    schedule_params.heatup_steps = EnvironmentSteps(0)
    #########
    # Agent #
    #########
    trainable_agents_list = list()
    non_trainable_agents_list = list()
    for agent in agent_list:
        if agent.network_settings:
            training_algorithm = agent.ctrl.model_metadata.training_algorithm
            params = get_updated_hyper_parameters(hp_dict, training_algorithm)
            if TrainingAlgorithm.SAC.value == training_algorithm:
                agent_params = get_sac_params(DeepRacerSACAgentParams(), agent, params, run_type)
            else:
                agent_params = get_clipped_ppo_params(DeepRacerClippedPPOAgentParams(), agent, params)
            agent_params.env_agent = agent
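            # Build the per-observation preprocessing pipeline. Camera images
            # are grayscaled, cast to uint8 and stacked; stereo is cast to
            # uint8; lidar variants are clipped or discretized into sectors.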
            input_filter = InputFilter(is_a_reference_filter=True)
            for observation in agent.network_settings['input_embedders'].keys():
                if observation in (Input.LEFT_CAMERA.value, Input.CAMERA.value,
                                   Input.OBSERVATION.value):
                    input_filter.add_observation_filter(observation, 'to_grayscale',
                                                        ObservationRGBToYFilter())
                    input_filter.add_observation_filter(observation, 'to_uint8',
                                                        ObservationToUInt8Filter(0, 255))
                    input_filter.add_observation_filter(observation, 'stacking',
                                                        ObservationStackingFilter(1))
                if observation == Input.STEREO.value:
                    input_filter.add_observation_filter(observation, 'to_uint8',
                                                        ObservationToUInt8Filter(0, 255))
                if observation == Input.LIDAR.value:
                    input_filter.add_observation_filter(observation, 'clipping',
                                                        ObservationClippingFilter(0.15, 1.0))
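                # SECTOR_LIDAR reduces the scan to a fixed number of binary
                # sectors using package-level constants; DISCRETIZED_SECTOR_LIDAR
                # instead takes its sector geometry from the model metadata.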
                if observation == Input.SECTOR_LIDAR.value:
                    sector_binary_filter = ObservationSectorDiscretizeFilter(
                        num_sectors=NUMBER_OF_LIDAR_SECTORS,
                        num_values_per_sector=1,
                        clipping_dist=SECTOR_LIDAR_CLIPPING_DIST)
                    input_filter.add_observation_filter(observation, 'binary',
                                                        sector_binary_filter)
                if observation == Input.DISCRETIZED_SECTOR_LIDAR.value:
                    num_sectors = agent.ctrl.model_metadata.lidar_num_sectors
                    num_values_per_sector = agent.ctrl.model_metadata.lidar_num_values_per_sector
                    clipping_dist = agent.ctrl.model_metadata.lidar_clipping_dist
                    sector_discretize_filter = ObservationSectorDiscretizeFilter(
                        num_sectors=num_sectors,
                        num_values_per_sector=num_values_per_sector,
                        clipping_dist=clipping_dist)
                    input_filter.add_observation_filter(observation, 'discrete',
                                                        sector_discretize_filter)
            agent_params.input_filter = input_filter()
            trainable_agents_list.append(agent_params)
        else:
            non_trainable_agents_list.append(agent)
    ###############
    # Environment #
    ###############
    env_params = DeepRacerRacetrackEnvParameters()
    env_params.agents_params = trainable_agents_list
    env_params.non_trainable_agents = non_trainable_agents_list
    env_params.level = 'DeepRacerRacetrackEnv-v0'
    env_params.run_phase_subject = run_phase_subject
    env_params.enable_domain_randomization = enable_domain_randomization
    env_params.done_condition = done_condition
    env_params.pause_physics = pause_physics
    env_params.unpause_physics = unpause_physics
    vis_params = VisualizationParameters()
    vis_params.dump_mp4 = False
    ########
    # Test #
    ########
    preset_validation_params = PresetValidationParameters()
    preset_validation_params.test = True
    preset_validation_params.min_reward_threshold = 400
    preset_validation_params.max_episodes_to_achieve_reward = 10000
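    # Preset validation settings feed rl_coach's preset test harness; the
    # reward threshold and episode cap gate preset integration tests rather
    # than training itself.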
    graph_manager = MultiAgentGraphManager(agents_params=trainable_agents_list,
                                           env_params=env_params,
                                           schedule_params=schedule_params,
                                           vis_params=vis_params,
                                           preset_validation_params=preset_validation_params,
                                           done_condition=done_condition)
    return graph_manager, params_json
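
# A minimal usage sketch (hypothetical values; create_rollout_agent stands in
# for this package's agent factories and should be verified against them):
#
#   hp_dict = {"batch_size": 64, "lr": 0.0003,
#              "term_cond_max_episodes": 1000}     # hypothetical hyperparameters
#   agents = [create_rollout_agent(model_metadata)]  # hypothetical factory call
#   graph_manager, params_json = get_graph_manager(hp_dict=hp_dict,
#                                                  agent_list=agents,
#                                                  run_phase_subject=None)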