in src/markov/evaluation_worker.py [0:0]
def main():
    """ Main function for evaluation worker

    Parses the command-line arguments (most defaults are read from ROS
    parameters), checks that all per-agent S3 lists have the same length,
    configures the cameras, builds one rollout agent per model S3 bucket entry
    plus the obstacle and bot-car agents, creates the graph manager with its
    S3 data store and finally calls ``evaluation_worker``.

    Raises:
        GenericRolloutException: if a non-zero number of resets smaller than
            MIN_RESET_COUNT is requested.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--preset',
                        help="(string) Name of a preset to run "
                             "(class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='list(string) S3 bucket',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_BUCKET", ["gsaur-test"]))
    parser.add_argument('--s3_prefix',
                        help='list(string) S3 prefix',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_PREFIX", ["sagemaker"]))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument('-c', '--local_model_directory',
                        help='(string) Path to a folder containing a checkpoint '
                             'to restore the model from.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "EVALUATION"))
    # argparse's type=bool would call bool("False"), which is True for any
    # non-empty string; use the same converter that is applied to the default.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY", 2.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY", 5.0)))
    args = parser.parse_args()

    arg_s3_bucket = args.s3_bucket
    arg_s3_prefix = args.s3_prefix
    logger.info("S3 bucket: %s \n S3 prefix: %s", arg_s3_bucket, arg_s3_prefix)

    metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET')
    metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY')

    # Every per-agent setting is normalised with utils.force_list so that it
    # can be indexed by agent_index below.
    arg_s3_bucket, arg_s3_prefix = utils.force_list(arg_s3_bucket), utils.force_list(arg_s3_prefix)
    metrics_s3_buckets = utils.force_list(metrics_s3_buckets)
    metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys)
    validate_list = [arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets, metrics_s3_object_keys]

    # Simtrace and mp4 uploads are optional; their prefixes are only read
    # when the corresponding bucket parameter is set.
    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    # All per-agent lists must have one entry per agent. The predicate has to
    # be evaluated on each list: all([<lambda>, <list>]) only tests the
    # truthiness of the two objects and therefore never fails.
    expected_len = len(validate_list[0])
    if not all(len(entry) == expected_len for entry in validate_list):
        log_and_exit("Eval worker error: Incorrect arguments passed: {}"
                     .format(validate_list),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    # Zero resets is accepted; any other value must reach MIN_RESET_COUNT.
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException("number of resets is less than {}".format(MIN_RESET_COUNT))

    # A single agent uses the un-suffixed 'racecar'/'agent' names, several
    # agents use 'racecar_<index>'/'agent_<index>'.
    is_single_agent = len(arg_s3_bucket) == 1

    # Instantiate Cameras
    if is_single_agent:
        configure_camera(namespaces=['racecar'])
    else:
        configure_camera(namespaces=[
            'racecar_{}'.format(agent_index) for agent_index in range(len(arg_s3_bucket))])

    agent_list = []
    checkpoint_dict = {}
    simtrace_video_s3_writers = []
    start_positions = get_start_positions(len(arg_s3_bucket))
    done_condition = utils.str_to_done_condition(rospy.get_param("DONE_CONDITION", any))
    park_positions = utils.pos_2d_str_to_list(rospy.get_param("PARK_POSITIONS", []))
    # if not pass in park positions for all done condition case, use default
    if not park_positions:
        park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket]

    # Region used by the simtrace/mp4 writers: the AWS_REGION ROS parameter,
    # falling back to the parsed argument. It does not depend on the agent,
    # so it is looked up once instead of on every loop iteration.
    aws_region = rospy.get_param('AWS_REGION', args.aws_region)

    for agent_index, _ in enumerate(arg_s3_bucket):
        agent_name = 'agent' if is_single_agent else 'agent_{}'.format(agent_index)
        racecar_name = 'racecar' if is_single_agent else 'racecar_{}'.format(agent_index)
        agent_bucket = arg_s3_bucket[agent_index]
        agent_prefix = arg_s3_prefix[agent_index]

        # download model metadata
        model_metadata = ModelMetadata(bucket=agent_bucket,
                                       s3_key=get_s3_key(agent_prefix, MODEL_METADATA_S3_POSTFIX),
                                       region_name=args.aws_region,
                                       local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(agent_name))
        model_metadata_info = model_metadata.get_model_metadata_info()
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]

        # checkpoint s3 instance
        checkpoint = Checkpoint(bucket=agent_bucket,
                                s3_prefix=agent_prefix,
                                region_name=args.aws_region,
                                agent_name=agent_name,
                                checkpoint_dir=args.local_model_directory)
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())
        checkpoint_dict[agent_name] = checkpoint

        # Link names and topics are defined for 'racecar'; rewrite them for
        # this agent's racecar name.
        agent_config = {
            'model_metadata': model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace('racecar', racecar_name) for link_name in LINK_NAMES],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace('racecar', racecar_name) for velocity_topic in VELOCITY_TOPICS],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace('racecar', racecar_name) for steering_topic in STEERING_TOPICS],
                ConfigParams.CHANGE_START.value: utils.str2bool(rospy.get_param('CHANGE_START_POSITION', False)),
                ConfigParams.ALT_DIR.value: utils.str2bool(rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
                ConfigParams.MODEL_METADATA.value: model_metadata,
                ConfigParams.REWARD.value: reward_function,
                ConfigParams.AGENT_NAME.value: racecar_name,
                ConfigParams.VERSION.value: version,
                ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value: args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
                ConfigParams.RACE_TYPE.value: args.race_type,
                ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty,
                ConfigParams.START_POSITION.value: start_positions[agent_index],
                ConfigParams.DONE_CONDITION.value: done_condition}}

        metrics_s3_config = {MetricsS3Keys.METRICS_BUCKET.value: metrics_s3_buckets[agent_index],
                             MetricsS3Keys.METRICS_KEY.value: metrics_s3_object_keys[agent_index],
                             # Replaced rospy.get_param('AWS_REGION') to be equal to the argument being passed
                             # or default argument set
                             MetricsS3Keys.REGION.value: args.aws_region}

        if simtrace_s3_bucket:
            simtrace_video_s3_writers.append(
                SimtraceVideo(upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                              bucket=simtrace_s3_bucket[agent_index],
                              s3_prefix=simtrace_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(agent_name)))
        if mp4_s3_bucket:
            # One writer per camera view, all uploading to the same bucket/prefix.
            simtrace_video_s3_writers.extend([
                SimtraceVideo(upload_type=SimtraceVideoNames.PIP.value,
                              bucket=mp4_s3_bucket[agent_index],
                              s3_prefix=mp4_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format(agent_name)),
                SimtraceVideo(upload_type=SimtraceVideoNames.DEGREE45.value,
                              bucket=mp4_s3_bucket[agent_index],
                              s3_prefix=mp4_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format(agent_name)),
                SimtraceVideo(upload_type=SimtraceVideoNames.TOPVIEW.value,
                              bucket=mp4_s3_bucket[agent_index],
                              s3_prefix=mp4_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format(agent_name))])

        # NOTE(review): a new RunPhaseSubject is created for every agent, and
        # only the last one is handed to PhaseObserver and get_graph_manager
        # below - confirm this is intended for multi-agent runs.
        run_phase_subject = RunPhaseSubject()
        agent_list.append(create_rollout_agent(agent_config,
                                               EvalMetrics(agent_name, metrics_s3_config,
                                                           args.is_continuous),
                                               run_phase_subject))

    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)
    enable_domain_randomization = utils.str2bool(rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))

    sm_hyperparams_dict = {}

    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)

    graph_manager, _ = get_graph_manager(hp_dict=sm_hyperparams_dict, agent_list=agent_list,
                                         run_phase_subject=run_phase_subject,
                                         enable_domain_randomization=enable_domain_randomization,
                                         done_condition=done_condition,
                                         pause_physics=pause_physics,
                                         unpause_physics=unpause_physics)

    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                               graph_manager=graph_manager,
                                               ignore_lock=True)
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(
        graph_manager=graph_manager,
        number_of_trials=args.number_of_trials,
        task_parameters=task_parameters,
        simtrace_video_s3_writers=simtrace_video_s3_writers,
        is_continuous=args.is_continuous,
        park_positions=park_positions,
        race_type=args.race_type,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics
    )