in Advanced workshops/AI Driving Olympics 2019/challenge_train_DQN/src/training_worker.py [0:0]
def main():
    """Configure the training worker from the command line and start training.

    Steps, in order:

    1. Parse the command-line options (options not declared here are ignored).
    2. Start the Redis server; the trainer's pub/sub memory backend created
       further down connects to ``localhost:6379``.
    3. Load the model metadata to a local file, upload it under
       ``<s3_prefix>/model/`` and copy it into ``SM_MODEL_OUTPUT_DIR``.
    4. Register the gym environment.
    5. Obtain the graph manager: from a custom preset fetched from S3 when
       ``--preset_s3_key`` is given and usable, otherwise from
       ``get_graph_manager`` fed with the hyperparameters found in the
       ``SM_TRAINING_ENV`` environment variable.
    6. Write this host's IP address to S3.
    7. Optionally download a pre-trained model.
    8. Attach an S3-backed data store to the graph manager and hand over to
       ``training_worker``.
    """
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-pk', '--preset_s3_key',
                        help="(string) Name of a preset to download from S3",
                        type=str,
                        required=False)
    parser.add_argument('-ek', '--environment_s3_key',
                        help="(string) Name of an environment file to download from S3",
                        type=str,
                        required=False)
    parser.add_argument('--model_metadata_s3_key',
                        help="(string) Model Metadata File S3 Key",
                        type=str,
                        required=False)
    parser.add_argument('-c', '--checkpoint-dir',
                        help='(string) Path to a folder containing a checkpoint to write the model to.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--pretrained-checkpoint-dir',
                        help='(string) Path to a folder for downloading a pre-trained model',
                        type=str,
                        default=PRETRAINED_MODEL_DIR)
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET_PATH", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--framework',
                        help='(string) tensorflow or mxnet',
                        type=str,
                        default='tensorflow')
    parser.add_argument('--pretrained_s3_bucket',
                        help='(string) S3 bucket for pre-trained model',
                        type=str)
    parser.add_argument('--pretrained_s3_prefix',
                        help='(string) S3 prefix for pre-trained model',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("AWS_REGION", "us-east-1"))

    # Parse before starting Redis: argparse exits the process on -h/--help or
    # on a malformed option, and such a run should not start a server first.
    # parse_known_args() is used so that options this worker does not declare
    # are tolerated; the list of leftovers is not needed.
    args, _ = parser.parse_known_args()
    start_redis_server()

    s3_client = SageS3Client(bucket=args.s3_bucket, s3_prefix=args.s3_prefix, aws_region=args.aws_region)

    # Load the model metadata, then publish it to S3 and to the model output dir
    model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH, 'model_metadata.json')
    load_model_metadata(s3_client, args.model_metadata_s3_key, model_metadata_local_path)
    s3_client.upload_file(os.path.normpath("%s/model/model_metadata.json" % args.s3_prefix), model_metadata_local_path)
    shutil.copy2(model_metadata_local_path, SM_MODEL_OUTPUT_DIR)

    # Register the gym environment; this gives clients the ability to create
    # the environment object
    register(id=defaults.ENV_ID, entry_point=defaults.ENTRY_POINT,
             max_episode_steps=defaults.MAX_STEPS, reward_threshold=defaults.THRESHOLD)

    # Try the custom preset first. The flag ends up truthy only when the
    # preset was downloaded AND re-uploaded under this job's prefix.
    success_custom_preset = False
    if args.preset_s3_key:
        preset_local_path = "./markov/presets/preset.py"
        success_custom_preset = s3_client.download_file(s3_key=args.preset_s3_key, local_path=preset_local_path)
        if success_custom_preset:
            preset_location = "markov.presets.preset:graph_manager"
            graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
            # The upload result replaces the flag: if the upload fails, the
            # imported preset is discarded and the default path below runs.
            success_custom_preset = s3_client.upload_file(
                s3_key=os.path.normpath("%s/presets/preset.py" % args.s3_prefix), local_path=preset_local_path)
            if success_custom_preset:
                logger.info("Using preset: %s" % args.preset_s3_key)
        else:
            logger.info("Could not download the preset file. Using the default DeepRacer preset.")

    # Default path: build the graph manager from the SageMaker hyperparameters
    if not success_custom_preset:
        from markov.sagemaker_graph_manager import get_graph_manager

        params_blob = os.environ.get('SM_TRAINING_ENV', '')
        if params_blob:
            params = json.loads(params_blob)
            sm_hyperparams_dict = params["hyperparameters"]
        else:
            # Variable unset or empty: call get_graph_manager with no overrides
            sm_hyperparams_dict = {}
        graph_manager, robomaker_hyperparams_json = get_graph_manager(**sm_hyperparams_dict)
        s3_client.upload_hyperparameters(robomaker_hyperparams_json)
        logger.info("Uploaded hyperparameters.json to S3")

    host_ip_address = get_ip_from_host()
    s3_client.write_ip_config(host_ip_address)
    logger.info("Uploaded IP address information to S3: %s" % host_ip_address)

    # A pre-trained model is only attempted when a bucket is supplied; the
    # prefix has a default ('sagemaker'), so the bucket is the deciding option.
    use_pretrained_model = False
    if args.pretrained_s3_bucket and args.pretrained_s3_prefix:
        s3_client_pretrained = SageS3Client(bucket=args.pretrained_s3_bucket,
                                            s3_prefix=args.pretrained_s3_prefix,
                                            aws_region=args.aws_region)
        # NOTE(review): relies on download_model returning a truthy value on
        # success — confirm against SageS3Client.download_model.
        use_pretrained_model = s3_client_pretrained.download_model(args.pretrained_checkpoint_dir)

    memory_backend_params = RedisPubSubMemoryBackendParameters(redis_address="localhost",
                                                               redis_port=6379,
                                                               run_type='trainer',
                                                               channel=args.s3_prefix)

    ds_params_instance = S3BotoDataStoreParameters(bucket_name=args.s3_bucket,
                                                   checkpoint_dir=args.checkpoint_dir, aws_region=args.aws_region,
                                                   s3_folder=args.s3_prefix)
    graph_manager.data_store_params = ds_params_instance

    # The data store and the graph manager hold references to each other
    data_store = S3BotoDataStore(ds_params_instance)
    data_store.graph_manager = graph_manager
    graph_manager.data_store = data_store

    training_worker(
        graph_manager=graph_manager,
        checkpoint_dir=args.checkpoint_dir,
        use_pretrained_model=use_pretrained_model,
        framework=args.framework,
        memory_backend_params=memory_backend_params
    )