def main()

in reinforcement_learning/rl_deepracer_robomaker_coach_gazebo/src/markov/rollout_worker.py


def main():
    screen.set_use_colors(False)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint_dir",
        help="(string) Path to a folder containing a checkpoint to restore the model from.",
        type=str,
        default="./checkpoint",
    )
    parser.add_argument(
        "--s3_bucket",
        help="(string) S3 bucket",
        type=str,
        default=rospy.get_param("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"),
    )
    parser.add_argument(
        "--s3_prefix",
        help="(string) S3 prefix",
        type=str,
        default=rospy.get_param("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"),
    )
    parser.add_argument(
        "--num_workers",
        help="(int) The number of workers started in this pool",
        type=int,
        default=int(rospy.get_param("NUM_WORKERS", 1)),
    )
    parser.add_argument(
        "--rollout_idx", help="(int) The index of current rollout worker", type=int, default=0
    )
    parser.add_argument(
        "-r",
        "--redis_ip",
        help="(string) IP or host for the redis server",
        default="localhost",
        type=str,
    )
    parser.add_argument(
        "-rp", "--redis_port", help="(int) Port of the redis server", default=6379, type=int
    )
    parser.add_argument(
        "--aws_region",
        help="(string) AWS region",
        type=str,
        default=rospy.get_param("AWS_REGION", "us-east-1"),
    )
    parser.add_argument(
        "--reward_file_s3_key",
        help="(string) Reward File S3 Key",
        type=str,
        default=rospy.get_param("REWARD_FILE_S3_KEY", None),
    )
    parser.add_argument(
        "--model_metadata_s3_key",
        help="(string) Model Metadata File S3 Key",
        type=str,
        default=rospy.get_param("MODEL_METADATA_FILE_S3_KEY", None),
    )
    # For training job, reset is not allowed. penalty_seconds, off_track_penalty, and
    # collision_penalty will all default to 0.
    parser.add_argument(
        "--number_of_resets",
        help="(integer) Number of resets",
        type=int,
        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)),
    )
    parser.add_argument(
        "--penalty_seconds",
        help="(float) penalty second",
        type=float,
        default=float(rospy.get_param("PENALTY_SECONDS", 0.0)),
    )
    parser.add_argument(
        "--job_type",
        help="(string) job type",
        type=str,
        default=rospy.get_param("JOB_TYPE", "TRAINING"),
    )
    parser.add_argument(
        "--is_continuous",
        help="(boolean) is continuous after lap completion",
        # argparse's type=bool treats any non-empty string as True; parse the flag explicitly
        type=utils.str2bool,
        default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)),
    )
    parser.add_argument(
        "--race_type",
        help="(string) Race type",
        type=str,
        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"),
    )
    parser.add_argument(
        "--off_track_penalty",
        help="(float) off track penalty second",
        type=float,
        default=float(rospy.get_param("OFF_TRACK_PENALTY", 0.0)),
    )
    parser.add_argument(
        "--collision_penalty",
        help="(float) collision penalty second",
        type=float,
        default=float(rospy.get_param("COLLISION_PENALTY", 0.0)),
    )

    args = parser.parse_args()

    logger.info("S3 bucket: %s", args.s3_bucket)
    logger.info("S3 prefix: %s", args.s3_prefix)

    # Download and import reward function
    # TODO: replace 'agent' with name of each agent for multi-agent training
    reward_function_file = RewardFunction(
        bucket=args.s3_bucket,
        s3_key=args.reward_file_s3_key,
        region_name=args.aws_region,
        local_path=REWARD_FUCTION_LOCAL_PATH_FORMAT.format("agent"),
    )
    reward_function = reward_function_file.get_reward_function()

    # Instantiate Cameras
    configure_camera(namespaces=["racecar"])

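    # Download any user-supplied custom files (e.g. a Coach preset) from the shared S3 prefix;
    # a successfully downloaded preset overrides the default graph manager built further below.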
    preset_file_success, _ = download_custom_files_if_present(
        s3_bucket=args.s3_bucket, s3_prefix=args.s3_prefix, aws_region=args.aws_region
    )

    # download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format("agent"),
    )
    model_metadata_info = model_metadata.get_model_metadata_info()
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

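    # Assemble the per-agent configuration: control topics, reward function, model metadata,
    # and the reset/penalty settings parsed above.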
    agent_config = {
        "model_metadata": model_metadata,
        ConfigParams.CAR_CTRL_CONFIG.value: {
            ConfigParams.LINK_NAME_LIST.value: LINK_NAMES,
            ConfigParams.VELOCITY_LIST.value: VELOCITY_TOPICS,
            ConfigParams.STEERING_LIST.value: STEERING_TOPICS,
            ConfigParams.CHANGE_START.value: utils.str2bool(
                rospy.get_param("CHANGE_START_POSITION", True)
            ),
            ConfigParams.ALT_DIR.value: utils.str2bool(
                rospy.get_param("ALTERNATE_DRIVING_DIRECTION", False)
            ),
            ConfigParams.MODEL_METADATA.value: model_metadata,
            ConfigParams.REWARD.value: reward_function,
            ConfigParams.AGENT_NAME.value: "racecar",
            ConfigParams.VERSION.value: version,
            ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
            ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
            ConfigParams.NUMBER_OF_TRIALS.value: None,
            ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
            ConfigParams.RACE_TYPE.value: args.race_type,
            ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
            ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty,
        },
    }

    #! TODO each agent should have its own s3 bucket
    metrics_key = rospy.get_param("METRICS_S3_OBJECT_KEY")
    if args.num_workers > 1 and args.rollout_idx > 0:
        key_tuple = os.path.splitext(metrics_key)
        metrics_key = "{}_{}{}".format(key_tuple[0], str(args.rollout_idx), key_tuple[1])
    metrics_s3_config = {
        MetricsS3Keys.METRICS_BUCKET.value: rospy.get_param("METRICS_S3_BUCKET"),
        MetricsS3Keys.METRICS_KEY.value: metrics_key,
        MetricsS3Keys.REGION.value: rospy.get_param("AWS_REGION"),
    }

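    # Observer-pattern subject that fans out run-phase changes to the metrics sink and the
    # phase observer registered below.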
    run_phase_subject = RunPhaseSubject()

    agent_list = list()

    # Checkpoint S3 instance
    # TODO: replace 'agent' with 'agent_0' and so on for the multi-agent case
    checkpoint = Checkpoint(
        bucket=args.s3_bucket,
        s3_prefix=args.s3_prefix,
        region_name=args.aws_region,
        agent_name="agent",
        checkpoint_dir=args.checkpoint_dir,
    )

    agent_list.append(
        create_rollout_agent(
            agent_config,
            TrainingMetrics(
                agent_name="agent",
                s3_dict_metrics=metrics_s3_config,
                deepracer_checkpoint_json=checkpoint.deepracer_checkpoint_json,
                ckpnt_dir=os.path.join(args.checkpoint_dir, "agent"),
                run_phase_sink=run_phase_subject,
                use_model_picker=(args.rollout_idx == 0),
            ),
            run_phase_subject,
        )
    )
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())
    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

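    # Bridge run-phase updates with the /agent/training_phase ROS topic.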
    PhaseObserver("/agent/training_phase", run_phase_subject)

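    # Resolve upload targets for simulation traces and camera MP4s; only the first rollout
    # worker (rollout_idx 0) uploads videos, and multi-worker runs use per-worker simtrace prefixes.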
    aws_region = rospy.get_param("AWS_REGION", args.aws_region)
    simtrace_s3_bucket = rospy.get_param("SIMTRACE_S3_BUCKET", None)
    mp4_s3_bucket = rospy.get_param("MP4_S3_BUCKET", None) if args.rollout_idx == 0 else None
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param("SIMTRACE_S3_PREFIX")
        if args.num_workers > 1:
            simtrace_s3_object_prefix = os.path.join(
                simtrace_s3_object_prefix, str(args.rollout_idx)
            )
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param("MP4_S3_OBJECT_PREFIX")

    simtrace_video_s3_writers = []
    # TODO: replace 'agent' with 'agent_0' for multi agent training and
    # mp4_s3_object_prefix, mp4_s3_bucket will be a list, so need to access with index
    if simtrace_s3_bucket:
        simtrace_video_s3_writers.append(
            SimtraceVideo(
                upload_type=SimtraceVideoNames.SIMTRACE_TRAINING.value,
                bucket=simtrace_s3_bucket,
                s3_prefix=simtrace_s3_object_prefix,
                region_name=aws_region,
                local_path=SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format("agent"),
            )
        )
    if mp4_s3_bucket:
        simtrace_video_s3_writers.extend(
            [
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.PIP.value,
                    bucket=mp4_s3_bucket,
                    s3_prefix=mp4_s3_object_prefix,
                    region_name=aws_region,
                    local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format("agent"),
                ),
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.DEGREE45.value,
                    bucket=mp4_s3_bucket,
                    s3_prefix=mp4_s3_object_prefix,
                    region_name=aws_region,
                    local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format("agent"),
                ),
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.TOPVIEW.value,
                    bucket=mp4_s3_bucket,
                    s3_prefix=mp4_s3_object_prefix,
                    region_name=aws_region,
                    local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format("agent"),
                ),
            ]
        )

    # TODO: replace 'agent' with specific agent name for multi agent training
    ip_config = IpConfig(
        bucket=args.s3_bucket,
        s3_prefix=args.s3_prefix,
        region_name=args.aws_region,
        local_path=IP_ADDRESS_LOCAL_PATH.format("agent"),
    )
    redis_ip = ip_config.get_ip_config()

    # Download hyperparameters from SageMaker shared s3 bucket
    # TODO: replace 'agent' with name of each agent
    hyperparameters = Hyperparameters(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, HYPERPARAMETER_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=HYPERPARAMETER_LOCAL_PATH_FORMAT.format("agent"),
    )
    sm_hyperparams_dict = hyperparameters.get_hyperparameters_dict()

    enable_domain_randomization = utils.str2bool(
        rospy.get_param("ENABLE_DOMAIN_RANDOMIZATION", False)
    )
    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service("/gazebo/pause_physics_dr")
    rospy.wait_for_service("/gazebo/unpause_physics_dr")
    pause_physics = ServiceProxyWrapper("/gazebo/pause_physics_dr", Empty)
    unpause_physics = ServiceProxyWrapper("/gazebo/unpause_physics_dr", Empty)

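    # Build the Coach graph manager, either from a user-supplied preset file or from the
    # downloaded hyperparameters.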
    if preset_file_success:
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        graph_manager, _ = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=run_phase_subject,
            enable_domain_randomization=enable_domain_randomization,
            pause_physics=pause_physics,
            unpause_physics=unpause_physics,
        )

    # If num_episodes_between_training is smaller than num_workers, cancel the excess workers early.
    episode_steps_per_rollout = (
        graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    )
    # Reduce the number of workers if more were allocated than num_episodes_between_training
    if args.num_workers > episode_steps_per_rollout:
        logger.info(
            "Excess worker allocated. Reducing from {} to {}...".format(
                args.num_workers, episode_steps_per_rollout
            )
        )
        args.num_workers = episode_steps_per_rollout
    if args.rollout_idx >= episode_steps_per_rollout or args.rollout_idx >= args.num_workers:
        err_msg_format = "Exiting excess worker..."
        err_msg_format += (
            "(rollout_idx[{}] >= num_workers[{}] or num_episodes_between_training[{}])"
        )
        logger.info(
            err_msg_format.format(args.rollout_idx, args.num_workers, episode_steps_per_rollout)
        )
        # Close down the job
        utils.cancel_simulation_job()

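    # Redis pub/sub memory backend through which this rollout worker publishes collected
    # episodes back to the training worker.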
    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address=redis_ip,
        redis_port=args.redis_port,  # honor the parsed --redis_port (default 6379)
        run_type=str(RunType.ROLLOUT_WORKER),
        channel=args.s3_prefix,
        num_workers=args.num_workers,
        rollout_idx=args.rollout_idx,
    )

    graph_manager.memory_backend_params = memory_backend_params

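    # S3-backed data store used to sync model checkpoints with the training worker.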
    checkpoint_dict = {"agent": checkpoint}
    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager)

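    # Coach task parameters: restore the policy from the local checkpoint directory.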
    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.checkpoint_dir

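    # Start the rollout loop: collect episodes, publish experiences, and upload sim traces
    # and videos until training finishes.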
    rollout_worker(
        graph_manager=graph_manager,
        num_workers=args.num_workers,
        rollout_idx=args.rollout_idx,
        task_parameters=task_parameters,
        simtrace_video_s3_writers=simtrace_video_s3_writers,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics,
    )
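
For reference, a minimal sketch of how this entry point is typically launched as a ROS node; the node name and error handling here are illustrative assumptions, not taken from this file:

if __name__ == "__main__":
    try:
        # Hypothetical node name; the module may register under a different one.
        rospy.init_node("rollout_worker", anonymous=True)
        main()
    except Exception as ex:
        logger.error("Rollout worker failed: %s", ex)
        raise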