in robot_ws/src/rl_agent/markov/rollout_worker.py [0:0]
def main():
    """Run a rollout worker that feeds experience to a remote trainer.

    Steps performed:

    1. Parse command-line options. Each option falls back to an
       environment variable (or a hard-coded default) when it is not
       given on the command line.
    2. Build an ``S3BotoDataStore`` and use it to obtain the trainer's IP
       address.
    3. Load ``graph_manager`` from one of two places. If
       ``download_presets_if_present`` reports success, it comes from the
       preset downloaded to ``PRESET_LOCAL_PATH``. Otherwise it comes from
       the ``presets`` directory inside the ``markov`` package.
    4. Point the agent's memory at a Redis pub/sub backend on the trainer
       IP and ``TRAINER_REDIS_PORT``. The model S3 prefix is used as the
       channel name.
    5. Wait for a checkpoint in the local model directory, then hand
       control to ``rollout_worker``.

    Raises:
        ValueError: if no preset file can be determined.
        ImportError: if the ``markov`` package directory cannot be located
            when falling back to its bundled presets.
    """
    # importlib replaces the ``imp`` module, which was deprecated in
    # Python 3.4 and removed in 3.12.
    import importlib.util

    parser = argparse.ArgumentParser()
    parser.add_argument('--markov-preset-file',
                        help="(string) Name of a preset file to run in Markov's preset directory.",
                        type=str,
                        default=os.environ.get("MARKOV_PRESET_FILE", "meiro_runner.py"))
    parser.add_argument('-c', '--local-model-directory',
                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
                        type=str,
                        default=os.environ.get("LOCAL_MODEL_DIRECTORY", "./checkpoint"))
    # When the env var is set its value is a string; argparse runs string
    # defaults through ``type``, so the result is still an int.
    parser.add_argument('-n', '--num-rollout-workers',
                        help="(int) Number of workers for multi-process based agents, e.g. A3C",
                        default=os.environ.get("NUMBER_OF_ROLLOUT_WORKERS", 1),
                        type=int)
    parser.add_argument('--model-s3-bucket',
                        help='(string) S3 bucket where trained models are stored. It contains model checkpoints.',
                        type=str,
                        default=os.environ.get("MODEL_S3_BUCKET"))
    parser.add_argument('--model-s3-prefix',
                        help='(string) S3 prefix where trained models are stored. It contains model checkpoints.',
                        type=str,
                        default=os.environ.get("MODEL_S3_PREFIX"))
    parser.add_argument('--aws-region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("ROS_AWS_REGION", "us-west-2"))
    args = parser.parse_args()

    data_store_params_instance = S3BotoDataStoreParameters(bucket_name=args.model_s3_bucket,
                                                          s3_folder=args.model_s3_prefix,
                                                          checkpoint_dir=args.local_model_directory,
                                                          aws_region=args.aws_region)
    data_store = S3BotoDataStore(data_store_params_instance)

    # Get the IP of the trainer machine
    trainer_ip = data_store.get_ip()
    print("Received IP from SageMaker successfully: %s" % trainer_ip)

    preset_file_success = data_store.download_presets_if_present(PRESET_LOCAL_PATH)
    if preset_file_success:
        environment_file_success = data_store.download_environments_if_present(ENVIRONMENT_LOCAL_PATH)
        path_and_module = PRESET_LOCAL_PATH + args.markov_preset_file + ":graph_manager"
        graph_manager = short_dynamic_import(path_and_module, ignore_module_case=True)
        if environment_file_success:
            # Imported only for its side effects. Presumably this makes the
            # downloaded custom environments available (TODO: confirm).
            import robomaker.environments
        print("Using custom preset file!")
    elif args.markov_preset_file:
        # Locate the installed ``markov`` package directory. This is the
        # importlib equivalent of ``imp.find_module("markov")[1]``.
        markov_spec = importlib.util.find_spec("markov")
        if markov_spec is not None and markov_spec.submodule_search_locations:
            markov_dirs = list(markov_spec.submodule_search_locations)
        else:
            markov_dirs = []
        if not markov_dirs:
            raise ImportError("Unable to locate the 'markov' package directory")
        markov_path = markov_dirs[0]
        preset_location = os.path.join(markov_path, "presets", args.markov_preset_file)
        path_and_module = preset_location + ":graph_manager"
        graph_manager = short_dynamic_import(path_and_module, ignore_module_case=True)
        print("Using custom preset file from Markov presets directory!")
    else:
        raise ValueError("Unable to determine preset file")

    # Experience is published to the trainer's Redis instance; the model
    # S3 prefix doubles as the pub/sub channel name.
    memory_backend_params = RedisPubSubMemoryBackendParameters(redis_address=trainer_ip,
                                                               redis_port=TRAINER_REDIS_PORT,
                                                               run_type='worker',
                                                               channel=args.model_s3_prefix)
    graph_manager.agent_params.memory.register_var('memory_backend_params', memory_backend_params)
    graph_manager.data_store_params = data_store_params_instance
    graph_manager.data_store = data_store

    # Block until a checkpoint is available before starting rollouts.
    utils.wait_for_checkpoint(checkpoint_dir=args.local_model_directory, data_store=data_store)
    rollout_worker(
        graph_manager=graph_manager,
        checkpoint_dir=args.local_model_directory,
        data_store=data_store,
        num_workers=args.num_rollout_workers
    )