def main()

in Advanced workshops/reInvent2019-400/src/markov/rollout_worker.py [0:0]


def main():
    """Command-line entry point for a rollout worker process.

    Steps, in order:

    1. Parse command-line arguments (several defaults come from environment
       variables).
    2. Build the S3 client, then fetch the model metadata and the customer
       reward function. The process exits with status 1 when no reward file
       S3 key was supplied.
    3. Register the gym environment described by ``defaults``.
    4. Obtain the redis address via ``s3_client.get_ip()`` and try to
       download ``hyperparameters.json`` from ``<s3_prefix>/ip/``.
    5. Build the graph manager from a custom preset when one was downloaded,
       otherwise from ``get_graph_manager`` using the downloaded
       hyperparameters (empty dict when none were found).
    6. Wire up the redis memory backend parameters and the S3 data store,
       then hand control to ``rollout_worker``.
    """
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint_dir',
                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"))
    parser.add_argument('--num-workers',
                        help="(int) The number of workers started in this pool",
                        type=int,
                        default=1)
    parser.add_argument('-r', '--redis_ip',
                        help="(string) IP or host for the redis server",
                        default='localhost',
                        type=str)
    parser.add_argument('-rp', '--redis_port',
                        help="(int) Port of the redis server",
                        default=6379,
                        type=int)
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("APP_REGION", "us-east-1"))
    parser.add_argument('--reward_file_s3_key',
                        help='(string) Reward File S3 Key',
                        type=str,
                        default=os.environ.get("REWARD_FILE_S3_KEY", None))
    parser.add_argument('--model_metadata_s3_key',
                        help='(string) Model Metadata File S3 Key',
                        type=str,
                        default=os.environ.get("MODEL_METADATA_FILE_S3_KEY", None))

    args = parser.parse_args()

    # NOTE(review): this dumps the complete environment to stdout, which may
    # include credentials or other secrets -- confirm this is acceptable for
    # the log destination before keeping it.
    print("All env vars:", os.environ)

    s3_client = SageS3Client(bucket=args.s3_bucket, s3_prefix=args.s3_prefix, aws_region=args.aws_region)

    logger = utils.Logger(__name__, logging.INFO).get_logger()

    logger.info("S3 bucket: %s", args.s3_bucket)
    logger.info("S3 prefix: %s", args.s3_prefix)

    # Load the model metadata
    model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH, 'model_metadata.json')
    load_model_metadata(s3_client, args.model_metadata_s3_key, model_metadata_local_path)

    # Download the reward function; without its S3 key the job cannot proceed.
    if not args.reward_file_s3_key:
        utils.json_format_logger(
            "Customer reward S3 key not supplied for s3 bucket {} prefix {}. Job failed!".format(
                args.s3_bucket, args.s3_prefix),
            **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                            utils.SIMAPP_EVENT_ERROR_CODE_503))
        # No exception is being handled on this path, so there is no
        # traceback to print; exit directly.
        sys.exit(1)
    download_customer_reward_function(s3_client, args.reward_file_s3_key)

    # Register the gym environment; this gives clients the ability to create
    # the environment object.
    register(id=defaults.ENV_ID, entry_point=defaults.ENTRY_POINT,
             max_episode_steps=defaults.MAX_STEPS, reward_threshold=defaults.THRESHOLD)

    # NOTE(review): the redis address always comes from the S3 client; the
    # --redis_ip argument is parsed but not consulted here -- confirm intended.
    redis_ip = s3_client.get_ip()
    logger.info("Received IP from SageMaker successfully: %s", redis_ip)

    # Download hyperparameters from SageMaker
    hyperparams_local_path = "hyperparameters.json"
    hyperparams_s3_key = os.path.normpath(args.s3_prefix + "/ip/hyperparameters.json")
    hyperparameters_file_success = s3_client.download_file(s3_key=hyperparams_s3_key,
                                                           local_path=hyperparams_local_path)
    sm_hyperparams_dict = {}
    if hyperparameters_file_success:
        logger.info("Received Sagemaker hyperparameters successfully!")
        with open(hyperparams_local_path, encoding="utf-8") as fp:
            sm_hyperparams_dict = json.load(fp)
    else:
        logger.info("SageMaker hyperparameters not found.")

    preset_file_success, _ = download_custom_files_if_present(s3_client, args.s3_prefix)

    if preset_file_success:
        # A custom preset was downloaded: import its `graph_manager` object.
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        from markov.sagemaker_graph_manager import get_graph_manager
        graph_manager, _ = get_graph_manager(**sm_hyperparams_dict)

    # Honour --redis_port (default 6379) instead of a hard-coded port so the
    # command-line option actually takes effect.
    memory_backend_params = RedisPubSubMemoryBackendParameters(redis_address=redis_ip,
                                                               redis_port=args.redis_port,
                                                               run_type='worker',
                                                               channel=args.s3_prefix)

    ds_params_instance = S3BotoDataStoreParameters(bucket_name=args.s3_bucket,
                                                   checkpoint_dir=args.checkpoint_dir,
                                                   aws_region=args.aws_region,
                                                   s3_folder=args.s3_prefix)

    # The data store and the graph manager hold references to each other.
    data_store = S3BotoDataStore(ds_params_instance)
    data_store.graph_manager = graph_manager
    graph_manager.data_store = data_store

    rollout_worker(
        graph_manager=graph_manager,
        checkpoint_dir=args.checkpoint_dir,
        data_store=data_store,
        num_workers=args.num_workers,
        memory_backend_params=memory_backend_params
    )