def main()

in reinforcement_learning/rl_deepracer_robomaker_coach_gazebo/src/markov/training_worker.py [0:0]


def _parse_args():
    """Parse the command-line options understood by the training worker.

    Options that are not declared here are ignored (``parse_known_args``)
    rather than rejected.

    Returns:
        argparse.Namespace: the parsed known options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-pk",
        "--preset_s3_key",
        help="(string) Name of a preset to download from S3",
        type=str,
        required=False,
    )
    parser.add_argument(
        "-ek",
        "--environment_s3_key",
        help="(string) Name of an environment file to download from S3",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--model_metadata_s3_key",
        help="(string) Model Metadata File S3 Key",
        type=str,
        required=False,
    )
    parser.add_argument(
        "-c",
        "--checkpoint_dir",
        help="(string) Path to a folder containing a checkpoint to write the model to.",
        type=str,
        default="./checkpoint",
    )
    parser.add_argument(
        "--pretrained_checkpoint_dir",
        help="(string) Path to a folder for downloading a pre-trained model",
        type=str,
        default=PRETRAINED_MODEL_DIR,
    )
    parser.add_argument(
        "--s3_bucket",
        help="(string) S3 bucket",
        type=str,
        # NOTE(review): "gsaur-test" looks like a developer test bucket;
        # confirm this fallback is intended for production runs.
        default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET_PATH", "gsaur-test"),
    )
    parser.add_argument("--s3_prefix", help="(string) S3 prefix", type=str, default="sagemaker")
    parser.add_argument(
        "--framework", help="(string) tensorflow or mxnet", type=str, default="tensorflow"
    )
    parser.add_argument(
        "--pretrained_s3_bucket", help="(string) S3 bucket for pre-trained model", type=str
    )
    parser.add_argument(
        "--pretrained_s3_prefix",
        help="(string) S3 prefix for pre-trained model",
        type=str,
        default="sagemaker",
    )
    parser.add_argument(
        "--aws_region",
        help="(string) AWS region",
        type=str,
        default=os.environ.get("AWS_REGION", "us-east-1"),
    )

    args, _ = parser.parse_known_args()
    return args


def _read_sagemaker_hyperparameters():
    """Return the ``hyperparameters`` dict from the ``SM_TRAINING_ENV`` JSON blob.

    Returns an empty dict when the environment variable is unset or empty.
    """
    params_blob = os.environ.get("SM_TRAINING_ENV", "")
    if not params_blob:
        return {}
    return json.loads(params_blob)["hyperparameters"]


def _resolve_user_training_params(robomaker_hyperparams_json, sm_hyperparams_dict):
    """Return ``(batch_size, num_episodes_between_training)`` for the training loop.

    The hyperparameter JSON produced by ``get_graph_manager`` is preferred.
    When a custom preset supplied the graph manager, that JSON does not exist,
    so the raw SageMaker hyperparameters are used instead.

    Args:
        robomaker_hyperparams_json: JSON string returned by ``get_graph_manager``,
            or ``None`` when a custom preset was loaded.
        sm_hyperparams_dict: SageMaker hyperparameters (values may be strings).

    Raises:
        ValueError: if no hyperparameter JSON exists and the SageMaker
            hyperparameters lack one of the two required keys.
    """
    if robomaker_hyperparams_json is not None:
        hyperparams = json.loads(robomaker_hyperparams_json)
        return hyperparams["batch_size"], hyperparams["num_episodes_between_training"]
    try:
        return (
            int(sm_hyperparams_dict["batch_size"]),
            int(sm_hyperparams_dict["num_episodes_between_training"]),
        )
    except KeyError as err:
        raise ValueError(
            "A custom preset was used, so 'batch_size' and "
            "'num_episodes_between_training' must be provided as SageMaker "
            "hyperparameters (missing %s)" % err
        ) from err


def main():
    """Entry point of the SageMaker training worker.

    Steps:
      1. Download the model metadata, re-upload it under ``args.s3_prefix`` and
         copy it into ``SM_MODEL_OUTPUT_DIR``.
      2. Build the graph manager. A custom preset from S3 is used when one is
         requested and can be downloaded; otherwise the manager is built from
         the SageMaker hyperparameters.
      3. Persist the IP config.
      4. Optionally restore a pre-trained checkpoint.
      5. Attach the Redis memory backend and the S3 checkpoint data store.
      6. Run ``training_worker``.

    Raises:
        ValueError: when a custom preset is used but the SageMaker
            hyperparameters lack ``batch_size`` or
            ``num_episodes_between_training``.
    """
    screen.set_use_colors(False)

    args = _parse_args()

    s3_client = S3Client(region_name=args.aws_region, max_retry_attempts=0)

    # download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata_download = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format("agent"),
    )
    model_metadata_info = model_metadata_download.get_model_metadata_info()
    network_type = model_metadata_info[ModelMetadataKeys.NEURAL_NETWORK.value]
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

    # upload model metadata
    model_metadata_upload = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, MODEL_METADATA_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format("agent"),
    )
    model_metadata_upload.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())

    shutil.copy2(model_metadata_download.local_path, SM_MODEL_OUTPUT_DIR)

    # Only assigned by get_graph_manager(); stays None on the custom-preset path.
    robomaker_hyperparams_json = None
    success_custom_preset = False
    if args.preset_s3_key:
        preset_local_path = "./markov/presets/preset.py"
        try:
            s3_client.download_file(
                bucket=args.s3_bucket, s3_key=args.preset_s3_key, local_path=preset_local_path
            )
        except botocore.exceptions.ClientError as err:
            # Best effort: fall back to the default preset, but keep the cause visible.
            logger.info("Could not download the preset file. Using the default DeepRacer preset.")
            logger.info("Preset download error: %s", err)
        else:
            success_custom_preset = True
            preset_location = "markov.presets.preset:graph_manager"
            graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
            s3_client.upload_file(
                bucket=args.s3_bucket,
                s3_key=os.path.normpath("%s/presets/preset.py" % args.s3_prefix),
                local_path=preset_local_path,
                s3_kms_extra_args=utils.get_s3_kms_extra_args(),
            )
            logger.info("Using preset: %s", args.preset_s3_key)

    sm_hyperparams_dict = _read_sagemaker_hyperparameters()

    if not success_custom_preset:
        #! TODO each agent should have own config
        agent_config = {
            "model_metadata": model_metadata_download,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [],
                ConfigParams.VELOCITY_LIST.value: {},
                ConfigParams.STEERING_LIST.value: {},
                ConfigParams.CHANGE_START.value: None,
                ConfigParams.ALT_DIR.value: None,
                ConfigParams.MODEL_METADATA.value: model_metadata_download,
                ConfigParams.REWARD.value: None,
                ConfigParams.AGENT_NAME.value: "racecar",
            },
        }

        agent_list = [create_training_agent(agent_config)]

        graph_manager, robomaker_hyperparams_json = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=None,
            run_type=str(RunType.TRAINER),
        )

        # Upload hyperparameters to SageMaker shared s3 bucket
        hyperparameters = Hyperparameters(
            bucket=args.s3_bucket,
            s3_key=get_s3_key(args.s3_prefix, HYPERPARAMETER_S3_POSTFIX),
            region_name=args.aws_region,
        )
        hyperparameters.persist(
            hyperparams_json=robomaker_hyperparams_json,
            s3_kms_extra_args=utils.get_s3_kms_extra_args(),
        )

        # Attach sample collector to graph_manager only if sample count > 0
        max_sample_count = int(sm_hyperparams_dict.get("max_sample_count", 0))
        if max_sample_count > 0:
            sample_collector = SampleCollector(
                bucket=args.s3_bucket,
                s3_prefix=args.s3_prefix,
                region_name=args.aws_region,
                max_sample_count=max_sample_count,
                sampling_frequency=int(sm_hyperparams_dict.get("sampling_frequency", 1)),
            )
            graph_manager.sample_collector = sample_collector

    # Resolve these now so a misconfigured custom preset fails before any
    # further S3 / checkpoint side effects.
    user_batch_size, user_episode_per_rollout = _resolve_user_training_params(
        robomaker_hyperparams_json, sm_hyperparams_dict
    )

    # persist IP config from sagemaker to s3
    ip_config = IpConfig(
        bucket=args.s3_bucket, s3_prefix=args.s3_prefix, region_name=args.aws_region
    )
    ip_config.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())

    training_algorithm = model_metadata_download.training_algorithm
    output_head_format = FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[training_algorithm]

    use_pretrained_model = bool(args.pretrained_s3_bucket and args.pretrained_s3_prefix)
    # Handle backward compatibility
    if use_pretrained_model:
        # checkpoint s3 instance for pretrained model
        # TODO: replace 'agent' for multiagent training
        checkpoint = Checkpoint(
            bucket=args.pretrained_s3_bucket,
            s3_prefix=args.pretrained_s3_prefix,
            region_name=args.aws_region,
            agent_name="agent",
            checkpoint_dir=args.pretrained_checkpoint_dir,
            output_head_format=output_head_format,
        )
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args(),
        )
        # add checkpoint into checkpoint_dict
        checkpoint_dict = {"agent": checkpoint}
        # load pretrained model
        ds_params_instance_pretrained = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)
        data_store_pretrained = S3BotoDataStore(ds_params_instance_pretrained, graph_manager, True)
        data_store_pretrained.load_from_store()

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address="localhost",
        redis_port=6379,
        run_type=str(RunType.TRAINER),
        channel=args.s3_prefix,
        network_type=network_type,
    )

    graph_manager.memory_backend_params = memory_backend_params

    # checkpoint s3 instance for training model
    checkpoint = Checkpoint(
        bucket=args.s3_bucket,
        s3_prefix=args.s3_prefix,
        region_name=args.aws_region,
        agent_name="agent",
        checkpoint_dir=args.checkpoint_dir,
        output_head_format=output_head_format,
    )
    checkpoint_dict = {"agent": checkpoint}
    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)

    graph_manager.data_store_params = ds_params_instance

    graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager)

    task_parameters = TaskParameters()
    task_parameters.experiment_path = SM_MODEL_OUTPUT_DIR
    task_parameters.checkpoint_save_secs = 20
    if use_pretrained_model:
        task_parameters.checkpoint_restore_path = args.pretrained_checkpoint_dir
    task_parameters.checkpoint_save_dir = args.checkpoint_dir

    training_worker(
        graph_manager=graph_manager,
        task_parameters=task_parameters,
        user_batch_size=user_batch_size,
        user_episode_per_rollout=user_episode_per_rollout,
        training_algorithm=training_algorithm,
    )