in Advanced workshops/AI Driving Olympics 2019/challenge_train_w_PPO/src/markov/rollout_worker.py [0:0]
def main():
    """Entry point for a rollout (simulation) worker process.

    Steps, in order:
      1. Parse command-line arguments (defaults come from environment vars).
      2. Download the model metadata and the customer reward function from S3.
         If no reward-function S3 key is supplied, log a system error and exit
         the process with status 1.
      3. Register the gym environment.
      4. Read the Redis IP via ``s3_client.get_ip()`` and, if present, the
         SageMaker hyperparameters JSON from ``<s3_prefix>/ip/``.
      5. Build the graph manager, either from a custom preset downloaded from
         S3 or from ``markov.sagemaker_graph_manager.get_graph_manager``.
      6. Wire up the Redis pub/sub memory backend and the S3 data store.
      7. Hand control to ``rollout_worker``.
    """
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint_dir',
                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"))
    parser.add_argument('--num-workers',
                        help="(int) The number of workers started in this pool",
                        type=int,
                        default=1)
    # NOTE(review): --redis_ip is parsed but not used below; the Redis IP is
    # always taken from s3_client.get_ip(). Confirm this override is intended.
    parser.add_argument('-r', '--redis_ip',
                        help="(string) IP or host for the redis server",
                        default='localhost',
                        type=str)
    parser.add_argument('-rp', '--redis_port',
                        help="(int) Port of the redis server",
                        default=6379,
                        type=int)
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("APP_REGION", "us-east-1"))
    parser.add_argument('--reward_file_s3_key',
                        help='(string) Reward File S3 Key',
                        type=str,
                        default=os.environ.get("REWARD_FILE_S3_KEY", None))
    parser.add_argument('--model_metadata_s3_key',
                        help='(string) Model Metadata File S3 Key',
                        type=str,
                        default=os.environ.get("MODEL_METADATA_FILE_S3_KEY", None))
    args = parser.parse_args()

    s3_client = SageS3Client(bucket=args.s3_bucket, s3_prefix=args.s3_prefix,
                             aws_region=args.aws_region)
    logger.info("S3 bucket: %s", args.s3_bucket)
    logger.info("S3 prefix: %s", args.s3_prefix)

    # Load the model metadata
    model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH, 'model_metadata.json')
    load_model_metadata(s3_client, args.model_metadata_s3_key, model_metadata_local_path)

    # Download the reward function; the job cannot proceed without it.
    if not args.reward_file_s3_key:
        utils.json_format_logger("Customer reward S3 key not supplied for s3 bucket {} prefix {}. Job failed!".format(args.s3_bucket, args.s3_prefix),
                                 **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
        # No exception is being handled here, so traceback.print_exc() would
        # only print "NoneType: None"; exit directly instead.
        sys.exit(1)
    download_customer_reward_function(s3_client, args.reward_file_s3_key)

    # Register the gym environment so clients can create the environment object
    register(id=defaults.ENV_ID, entry_point=defaults.ENTRY_POINT,
             max_episode_steps=defaults.MAX_STEPS, reward_threshold=defaults.THRESHOLD)

    redis_ip = s3_client.get_ip()
    logger.info("Received IP from SageMaker successfully: %s", redis_ip)

    # Download hyperparameters from SageMaker (optional; fall back to defaults)
    hyperparams_s3_key = os.path.normpath(args.s3_prefix + "/ip/hyperparameters.json")
    hyperparameters_file_success = s3_client.download_file(s3_key=hyperparams_s3_key,
                                                           local_path="hyperparameters.json")
    sm_hyperparams_dict = {}
    if hyperparameters_file_success:
        logger.info("Received Sagemaker hyperparameters successfully!")
        with open("hyperparameters.json") as fp:
            sm_hyperparams_dict = json.load(fp)
    else:
        logger.info("SageMaker hyperparameters not found.")

    # Prefer a custom preset uploaded to S3; otherwise build the default graph
    # manager from the (possibly empty) SageMaker hyperparameters.
    preset_file_success, _ = download_custom_files_if_present(s3_client, args.s3_prefix)
    if preset_file_success:
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        from markov.sagemaker_graph_manager import get_graph_manager
        graph_manager, _ = get_graph_manager(**sm_hyperparams_dict)

    # Honor --redis_port (previously ignored in favor of a hard-coded 6379,
    # which is still the argument's default).
    memory_backend_params = RedisPubSubMemoryBackendParameters(redis_address=redis_ip,
                                                               redis_port=args.redis_port,
                                                               run_type='worker',
                                                               channel=args.s3_prefix)

    ds_params_instance = S3BotoDataStoreParameters(bucket_name=args.s3_bucket,
                                                   checkpoint_dir=args.checkpoint_dir,
                                                   aws_region=args.aws_region,
                                                   s3_folder=args.s3_prefix)
    data_store = S3BotoDataStore(ds_params_instance)
    # The data store and graph manager hold references to each other.
    data_store.graph_manager = graph_manager
    graph_manager.data_store = data_store

    rollout_worker(
        graph_manager=graph_manager,
        checkpoint_dir=args.checkpoint_dir,
        data_store=data_store,
        num_workers=args.num_workers,
        memory_backend_params=memory_backend_params,
    )