def main()

in src/markov/evaluation_worker.py [0:0]


def main():
    """ Main function for evaluation worker

    Parses the command line arguments (defaults are read from the ROS parameter
    server), validates that all per-agent S3 argument lists have the same length,
    builds one rollout agent per model S3 bucket plus the obstacle and bot car
    agents, creates the graph manager and finally calls evaluation_worker.

    Raises:
        GenericRolloutException: if number_of_resets is non-zero and smaller
            than MIN_RESET_COUNT.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--preset',
                        help="(string) Name of a preset to run \
                             (class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='list(string) S3 bucket',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_BUCKET", ["gsaur-test"]))
    parser.add_argument('--s3_prefix',
                        help='list(string) S3 prefix',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_PREFIX", ["sagemaker"]))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument('-c', '--local_model_directory',
                        help='(string) Path to a folder containing a checkpoint \
                             to restore the model from.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "EVALUATION"))
    # type=bool would turn every non-empty string (including "False") into True,
    # so the same converter that is applied to the default is used for parsing.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY", 2.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY", 5.0)))

    args = parser.parse_args()
    arg_s3_bucket = args.s3_bucket
    arg_s3_prefix = args.s3_prefix
    logger.info("S3 bucket: %s \n S3 prefix: %s", arg_s3_bucket, arg_s3_prefix)

    metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET')
    metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY')

    arg_s3_bucket, arg_s3_prefix = utils.force_list(arg_s3_bucket), utils.force_list(arg_s3_prefix)
    metrics_s3_buckets = utils.force_list(metrics_s3_buckets)
    metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys)

    # Every list collected here is indexed by agent_index further down, so all
    # of them are checked to have the same number of entries.
    validate_list = [arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets, metrics_s3_object_keys]

    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    # Compare each list's length with the first one. The previous form,
    # all([lambda ..., validate_list]), only tested the truthiness of a lambda
    # and of the list itself, so it could never detect a mismatch.
    expected_length = len(validate_list[0])
    if not all(len(entry) == expected_length for entry in validate_list):
        log_and_exit("Eval worker error: Incorrect arguments passed: {}"
                         .format(validate_list),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException("number of resets is less than {}".format(MIN_RESET_COUNT))

    num_agents = len(arg_s3_bucket)
    # With a single model the un-suffixed 'agent'/'racecar' names are used,
    # otherwise names are suffixed with the agent index.
    is_single_agent = num_agents == 1

    # Instantiate Cameras
    if is_single_agent:
        configure_camera(namespaces=['racecar'])
    else:
        configure_camera(namespaces=[
            'racecar_{}'.format(str(agent_index)) for agent_index in range(num_agents)])

    agent_list = []
    s3_bucket_dict = {}
    s3_prefix_dict = {}
    checkpoint_dict = {}
    simtrace_video_s3_writers = []
    start_positions = get_start_positions(num_agents)
    done_condition = utils.str_to_done_condition(rospy.get_param("DONE_CONDITION", any))
    park_positions = utils.pos_2d_str_to_list(rospy.get_param("PARK_POSITIONS", []))
    # if not pass in park positions for all done condition case, use default
    if not park_positions:
        park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket]

    # Region handed to the simtrace/mp4 writers; it does not depend on the
    # agent, so it is looked up once instead of once per loop iteration.
    aws_region = rospy.get_param('AWS_REGION', args.aws_region)

    for agent_index, _ in enumerate(arg_s3_bucket):
        agent_name = 'agent' if is_single_agent else 'agent_{}'.format(str(agent_index))
        racecar_name = 'racecar' if is_single_agent else 'racecar_{}'.format(str(agent_index))
        s3_bucket_dict[agent_name] = arg_s3_bucket[agent_index]
        s3_prefix_dict[agent_name] = arg_s3_prefix[agent_index]

        # download model metadata
        model_metadata = ModelMetadata(bucket=arg_s3_bucket[agent_index],
                                       s3_key=get_s3_key(arg_s3_prefix[agent_index], MODEL_METADATA_S3_POSTFIX),
                                       region_name=args.aws_region,
                                       local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(agent_name))
        model_metadata_info = model_metadata.get_model_metadata_info()
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]

        # checkpoint s3 instance
        checkpoint = Checkpoint(bucket=arg_s3_bucket[agent_index],
                                s3_prefix=arg_s3_prefix[agent_index],
                                region_name=args.aws_region,
                                agent_name=agent_name,
                                checkpoint_dir=args.local_model_directory)
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

        checkpoint_dict[agent_name] = checkpoint

        agent_config = {
            'model_metadata': model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace('racecar', racecar_name) for link_name in LINK_NAMES],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace('racecar', racecar_name) for velocity_topic in VELOCITY_TOPICS],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace('racecar', racecar_name) for steering_topic in STEERING_TOPICS],
                ConfigParams.CHANGE_START.value: utils.str2bool(rospy.get_param('CHANGE_START_POSITION', False)),
                ConfigParams.ALT_DIR.value: utils.str2bool(rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
                ConfigParams.MODEL_METADATA.value: model_metadata,
                ConfigParams.REWARD.value: reward_function,
                ConfigParams.AGENT_NAME.value: racecar_name,
                ConfigParams.VERSION.value: version,
                ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value: args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
                ConfigParams.RACE_TYPE.value: args.race_type,
                ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty,
                ConfigParams.START_POSITION.value: start_positions[agent_index],
                ConfigParams.DONE_CONDITION.value: done_condition}}

        metrics_s3_config = {MetricsS3Keys.METRICS_BUCKET.value: metrics_s3_buckets[agent_index],
                             MetricsS3Keys.METRICS_KEY.value: metrics_s3_object_keys[agent_index],
                             # Replaced rospy.get_param('AWS_REGION') to be equal to the argument being passed
                             # or default argument set
                             MetricsS3Keys.REGION.value: args.aws_region}

        if simtrace_s3_bucket:
            simtrace_video_s3_writers.append(
                SimtraceVideo(upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                              bucket=simtrace_s3_bucket[agent_index],
                              s3_prefix=simtrace_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(agent_name)))
        if mp4_s3_bucket:
            simtrace_video_s3_writers.extend([
                SimtraceVideo(upload_type=SimtraceVideoNames.PIP.value,
                              bucket=mp4_s3_bucket[agent_index],
                              s3_prefix=mp4_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format(agent_name)),
                SimtraceVideo(upload_type=SimtraceVideoNames.DEGREE45.value,
                              bucket=mp4_s3_bucket[agent_index],
                              s3_prefix=mp4_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format(agent_name)),
                SimtraceVideo(upload_type=SimtraceVideoNames.TOPVIEW.value,
                              bucket=mp4_s3_bucket[agent_index],
                              s3_prefix=mp4_s3_object_prefix[agent_index],
                              region_name=aws_region,
                              local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format(agent_name))])

        # NOTE(review): a new RunPhaseSubject is created for every agent, but only
        # the one from the last iteration is handed to PhaseObserver and
        # get_graph_manager below - confirm this is intended.
        run_phase_subject = RunPhaseSubject()
        agent_list.append(create_rollout_agent(agent_config, EvalMetrics(agent_name, metrics_s3_config,
                                                                         args.is_continuous),
                                               run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)
    enable_domain_randomization = utils.str2bool(rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))

    sm_hyperparams_dict = {}

    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)

    graph_manager, _ = get_graph_manager(hp_dict=sm_hyperparams_dict, agent_list=agent_list,
                                         run_phase_subject=run_phase_subject,
                                         enable_domain_randomization=enable_domain_randomization,
                                         done_condition=done_condition,
                                         pause_physics=pause_physics,
                                         unpause_physics=unpause_physics)

    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                               graph_manager=graph_manager,
                                               ignore_lock=True)
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(
        graph_manager=graph_manager,
        number_of_trials=args.number_of_trials,
        task_parameters=task_parameters,
        simtrace_video_s3_writers=simtrace_video_s3_writers,
        is_continuous=args.is_continuous,
        park_positions=park_positions,
        race_type=args.race_type,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics
    )