in reinforcement_learning/rl_deepracer_robomaker_coach_gazebo/src/markov/evaluation_worker.py [0:0]
def main():
    """Main function for evaluation worker.

    Parses command-line arguments (most defaults are read from ROS
    parameters) and validates that every per-agent S3 list has one entry per
    agent. For each agent it then:

    * downloads the model metadata;
    * makes the coach checkpoint compatible for pre-``SIMAPP_VERSION_2``
      models;
    * points the coach checkpoint at the best DeepRacer checkpoint;
    * builds the rollout agent and its simtrace/MP4 upload writers.

    Finally it adds the obstacle and bot-car agents, builds the graph manager
    with an S3 data store, and runs ``evaluation_worker``.

    Exits through ``log_and_exit`` when the per-agent S3 lists differ in
    length. Raises ``GenericRolloutException`` when a non-zero
    ``--number_of_resets`` is below ``MIN_RESET_COUNT``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",
        "--preset",
        help="(string) Name of a preset to run "
        "(class name from the 'presets' directory.)",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--s3_bucket",
        help="list(string) S3 bucket",
        type=str,
        nargs="+",
        default=rospy.get_param("MODEL_S3_BUCKET", ["gsaur-test"]),
    )
    parser.add_argument(
        "--s3_prefix",
        help="list(string) S3 prefix",
        type=str,
        nargs="+",
        default=rospy.get_param("MODEL_S3_PREFIX", ["sagemaker"]),
    )
    parser.add_argument(
        "--aws_region",
        help="(string) AWS region",
        type=str,
        default=rospy.get_param("AWS_REGION", "us-east-1"),
    )
    parser.add_argument(
        "--number_of_trials",
        help="(integer) Number of trials",
        type=int,
        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)),
    )
    parser.add_argument(
        "-c",
        "--local_model_directory",
        help="(string) Path to a folder containing a checkpoint "
        "to restore the model from.",
        type=str,
        default="./checkpoint",
    )
    parser.add_argument(
        "--number_of_resets",
        help="(integer) Number of resets",
        type=int,
        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)),
    )
    parser.add_argument(
        "--penalty_seconds",
        help="(float) penalty second",
        type=float,
        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)),
    )
    parser.add_argument(
        "--job_type",
        help="(string) job type",
        type=str,
        default=rospy.get_param("JOB_TYPE", "EVALUATION"),
    )
    parser.add_argument(
        "--is_continuous",
        help="(boolean) is continuous after lap completion",
        # argparse's type=bool maps every non-empty string (including
        # "False") to True, so parse with the same helper as the default.
        type=utils.str2bool,
        default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)),
    )
    parser.add_argument(
        "--race_type",
        help="(string) Race type",
        type=str,
        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"),
    )
    parser.add_argument(
        "--off_track_penalty",
        help="(float) off track penalty second",
        type=float,
        default=float(rospy.get_param("OFF_TRACK_PENALTY", 2.0)),
    )
    parser.add_argument(
        "--collision_penalty",
        help="(float) collision penalty second",
        type=float,
        default=float(rospy.get_param("COLLISION_PENALTY", 5.0)),
    )
    args = parser.parse_args()
    logger.info("S3 bucket: %s \n S3 prefix: %s", args.s3_bucket, args.s3_prefix)

    # Normalise every per-agent setting to a list (one entry per agent).
    arg_s3_bucket = utils.force_list(args.s3_bucket)
    arg_s3_prefix = utils.force_list(args.s3_prefix)
    metrics_s3_buckets = utils.force_list(rospy.get_param("METRICS_S3_BUCKET"))
    metrics_s3_object_keys = utils.force_list(rospy.get_param("METRICS_S3_OBJECT_KEY"))
    validate_list = [arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets, metrics_s3_object_keys]

    # Simtrace and MP4 uploads are optional; their prefixes are only read
    # (and validated) when the corresponding bucket is configured.
    simtrace_s3_bucket = rospy.get_param("SIMTRACE_S3_BUCKET", None)
    mp4_s3_bucket = rospy.get_param("MP4_S3_BUCKET", None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = utils.force_list(rospy.get_param("SIMTRACE_S3_PREFIX"))
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = utils.force_list(rospy.get_param("MP4_S3_OBJECT_PREFIX"))
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    # Every per-agent list must have exactly one entry per agent.
    num_agents = len(arg_s3_bucket)
    if not all(len(per_agent) == num_agents for per_agent in validate_list):
        log_and_exit(
            "Eval worker error: Incorrect arguments passed: {}".format(validate_list),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException("number of resets is less than {}".format(MIN_RESET_COUNT))

    # A single agent uses un-suffixed names; multiple agents are indexed.
    if num_agents == 1:
        agent_names = ["agent"]
        racecar_names = ["racecar"]
    else:
        agent_names = ["agent_{}".format(index) for index in range(num_agents)]
        racecar_names = ["racecar_{}".format(index) for index in range(num_agents)]

    # Instantiate Cameras
    configure_camera(namespaces=racecar_names)

    agent_list = list()
    s3_bucket_dict = dict()
    s3_prefix_dict = dict()
    checkpoint_dict = dict()
    simtrace_video_s3_writers = []
    start_positions = get_start_positions(num_agents)
    done_condition = utils.str_to_done_condition(rospy.get_param("DONE_CONDITION", any))
    park_positions = utils.pos_2d_str_to_list(rospy.get_param("PARK_POSITIONS", []))
    # if not pass in park positions for all done condition case, use default
    if not park_positions:
        park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket]

    # One subject shared by every rollout agent. PhaseObserver and the graph
    # manager below each receive a single subject, so per-agent subjects
    # would leave all but the last agent unattached.
    run_phase_subject = RunPhaseSubject()

    for agent_index, (agent_name, racecar_name) in enumerate(zip(agent_names, racecar_names)):
        s3_bucket = arg_s3_bucket[agent_index]
        s3_prefix = arg_s3_prefix[agent_index]
        s3_bucket_dict[agent_name] = s3_bucket
        s3_prefix_dict[agent_name] = s3_prefix
        # download model metadata
        model_metadata = ModelMetadata(
            bucket=s3_bucket,
            s3_key=get_s3_key(s3_prefix, MODEL_METADATA_S3_POSTFIX),
            region_name=args.aws_region,
            local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(agent_name),
        )
        model_metadata_info = model_metadata.get_model_metadata_info()
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]
        # checkpoint s3 instance
        checkpoint = Checkpoint(
            bucket=s3_bucket,
            s3_prefix=s3_prefix,
            region_name=args.aws_region,
            agent_name=agent_name,
            checkpoint_dir=args.local_model_directory,
        )
        # make coach checkpoint compatible (only for models older than SIMAPP_VERSION_2)
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args(),
        )
        checkpoint_dict[agent_name] = checkpoint

        agent_config = {
            "model_metadata": model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                # Re-namespace the default "racecar" links/topics for this agent.
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace("racecar", racecar_name) for link_name in LINK_NAMES
                ],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace("racecar", racecar_name)
                    for velocity_topic in VELOCITY_TOPICS
                ],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace("racecar", racecar_name)
                    for steering_topic in STEERING_TOPICS
                ],
                ConfigParams.CHANGE_START.value: utils.str2bool(
                    rospy.get_param("CHANGE_START_POSITION", False)
                ),
                ConfigParams.ALT_DIR.value: utils.str2bool(
                    rospy.get_param("ALTERNATE_DRIVING_DIRECTION", False)
                ),
                ConfigParams.MODEL_METADATA.value: model_metadata,
                ConfigParams.REWARD.value: reward_function,
                ConfigParams.AGENT_NAME.value: racecar_name,
                ConfigParams.VERSION.value: version,
                ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value: args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
                ConfigParams.RACE_TYPE.value: args.race_type,
                ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty,
                ConfigParams.START_POSITION.value: start_positions[agent_index],
                ConfigParams.DONE_CONDITION.value: done_condition,
            },
        }
        metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value: metrics_s3_buckets[agent_index],
            MetricsS3Keys.METRICS_KEY.value: metrics_s3_object_keys[agent_index],
            # Use the --aws_region argument (or its ROS-param default) for
            # every S3 client so an explicit CLI region is honoured consistently.
            MetricsS3Keys.REGION.value: args.aws_region,
        }
        if simtrace_s3_bucket:
            simtrace_video_s3_writers.append(
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                    bucket=simtrace_s3_bucket[agent_index],
                    s3_prefix=simtrace_s3_object_prefix[agent_index],
                    region_name=args.aws_region,
                    local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(agent_name),
                )
            )
        if mp4_s3_bucket:
            # One writer per camera video: PIP, 45-degree and top view.
            for upload_type, local_path_format in (
                (SimtraceVideoNames.PIP.value, CAMERA_PIP_MP4_LOCAL_PATH_FORMAT),
                (SimtraceVideoNames.DEGREE45.value, CAMERA_45DEGREE_LOCAL_PATH_FORMAT),
                (SimtraceVideoNames.TOPVIEW.value, CAMERA_TOPVIEW_LOCAL_PATH_FORMAT),
            ):
                simtrace_video_s3_writers.append(
                    SimtraceVideo(
                        upload_type=upload_type,
                        bucket=mp4_s3_bucket[agent_index],
                        s3_prefix=mp4_s3_object_prefix[agent_index],
                        region_name=args.aws_region,
                        local_path=local_path_format.format(agent_name),
                    )
                )
        agent_list.append(
            create_rollout_agent(
                agent_config,
                EvalMetrics(agent_name, metrics_s3_config, args.is_continuous),
                run_phase_subject,
            )
        )
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())
    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver("/agent/training_phase", run_phase_subject)
    enable_domain_randomization = utils.str2bool(
        rospy.get_param("ENABLE_DOMAIN_RANDOMIZATION", False)
    )
    # No hyperparameter overrides are passed to the graph manager during evaluation.
    sm_hyperparams_dict = {}

    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service("/gazebo/pause_physics_dr")
    rospy.wait_for_service("/gazebo/unpause_physics_dr")
    pause_physics = ServiceProxyWrapper("/gazebo/pause_physics_dr", Empty)
    unpause_physics = ServiceProxyWrapper("/gazebo/unpause_physics_dr", Empty)

    graph_manager, _ = get_graph_manager(
        hp_dict=sm_hyperparams_dict,
        agent_list=agent_list,
        run_phase_subject=run_phase_subject,
        enable_domain_randomization=enable_domain_randomization,
        done_condition=done_condition,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics,
    )
    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)
    graph_manager.data_store = S3BotoDataStore(
        params=ds_params_instance, graph_manager=graph_manager, ignore_lock=True
    )
    # NOTE(review): fixed seed presumably for repeatable evaluation runs — confirm.
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(
        graph_manager=graph_manager,
        number_of_trials=args.number_of_trials,
        task_parameters=task_parameters,
        simtrace_video_s3_writers=simtrace_video_s3_writers,
        is_continuous=args.is_continuous,
        park_positions=park_positions,
        race_type=args.race_type,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics,
    )