in reinforcement_learning/rl_deepracer_robomaker_coach_gazebo/src/markov/rollout_worker.py [0:0]
def _parse_args():
    """Build the rollout-worker command-line parser and parse ``sys.argv``.

    Most defaults are read from the ROS parameter server (``rospy.get_param``),
    so every option can come either from the command line or from ROS params.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint_dir",
        help="(string) Path to a folder containing a checkpoint to restore the model from.",
        type=str,
        default="./checkpoint",
    )
    parser.add_argument(
        "--s3_bucket",
        help="(string) S3 bucket",
        type=str,
        default=rospy.get_param("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"),
    )
    parser.add_argument(
        "--s3_prefix",
        help="(string) S3 prefix",
        type=str,
        default=rospy.get_param("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"),
    )
    parser.add_argument(
        "--num_workers",
        help="(int) The number of workers started in this pool",
        type=int,
        default=int(rospy.get_param("NUM_WORKERS", 1)),
    )
    parser.add_argument(
        "--rollout_idx", help="(int) The index of current rollout worker", type=int, default=0
    )
    # NOTE: --redis_ip is accepted for CLI compatibility, but main() takes the
    # Redis address from IpConfig instead of this argument.
    parser.add_argument(
        "-r",
        "--redis_ip",
        help="(string) IP or host for the redis server",
        default="localhost",
        type=str,
    )
    parser.add_argument(
        "-rp", "--redis_port", help="(int) Port of the redis server", default=6379, type=int
    )
    parser.add_argument(
        "--aws_region",
        help="(string) AWS region",
        type=str,
        default=rospy.get_param("AWS_REGION", "us-east-1"),
    )
    parser.add_argument(
        "--reward_file_s3_key",
        help="(string) Reward File S3 Key",
        type=str,
        default=rospy.get_param("REWARD_FILE_S3_KEY", None),
    )
    parser.add_argument(
        "--model_metadata_s3_key",
        help="(string) Model Metadata File S3 Key",
        type=str,
        default=rospy.get_param("MODEL_METADATA_FILE_S3_KEY", None),
    )
    # For a training job, resets are not allowed: penalty_seconds,
    # off_track_penalty and collision_penalty are all 0 by default.
    parser.add_argument(
        "--number_of_resets",
        help="(integer) Number of resets",
        type=int,
        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)),
    )
    parser.add_argument(
        "--penalty_seconds",
        help="(float) penalty second",
        type=float,
        default=float(rospy.get_param("PENALTY_SECONDS", 0.0)),
    )
    parser.add_argument(
        "--job_type",
        help="(string) job type",
        type=str,
        default=rospy.get_param("JOB_TYPE", "TRAINING"),
    )
    parser.add_argument(
        "--is_continuous",
        help="(boolean) is continous after lap completion",
        # argparse's ``type=bool`` maps any non-empty string (including "False")
        # to True, so parse the command-line value with str2bool, the same
        # helper used for the default.
        type=utils.str2bool,
        default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)),
    )
    parser.add_argument(
        "--race_type",
        help="(string) Race type",
        type=str,
        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"),
    )
    parser.add_argument(
        "--off_track_penalty",
        help="(float) off track penalty second",
        type=float,
        default=float(rospy.get_param("OFF_TRACK_PENALTY", 0.0)),
    )
    parser.add_argument(
        "--collision_penalty",
        help="(float) collision penalty second",
        type=float,
        default=float(rospy.get_param("COLLISION_PENALTY", 0.0)),
    )
    return parser.parse_args()


def _create_simtrace_video_s3_writers(rollout_idx, num_workers, aws_region):
    """Create the SimtraceVideo uploaders configured through ROS parameters.

    A simtrace uploader is created when ``SIMTRACE_S3_BUCKET`` is set. With
    more than one worker, its S3 prefix gets a ``/<rollout_idx>`` suffix. The
    three MP4 uploaders (PIP, 45 degree, top view) are created only for
    rollout worker 0, and only when ``MP4_S3_BUCKET`` is set.

    Args:
        rollout_idx (int): index of this rollout worker.
        num_workers (int): number of rollout workers in the pool.
        aws_region (str): region handed to every SimtraceVideo.

    Returns:
        list: SimtraceVideo instances; empty when no bucket is configured.
    """
    simtrace_s3_bucket = rospy.get_param("SIMTRACE_S3_BUCKET", None)
    # Only the first worker uploads camera videos.
    mp4_s3_bucket = rospy.get_param("MP4_S3_BUCKET", None) if rollout_idx == 0 else None
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param("SIMTRACE_S3_PREFIX")
        if num_workers > 1:
            # Keep each worker's simtrace under its own sub-prefix.
            simtrace_s3_object_prefix = os.path.join(simtrace_s3_object_prefix, str(rollout_idx))
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param("MP4_S3_OBJECT_PREFIX")

    simtrace_video_s3_writers = []
    # TODO: replace 'agent' with 'agent_0' for multi agent training and
    # mp4_s3_object_prefix, mp4_s3_bucket will be a list, so need to access with index
    if simtrace_s3_bucket:
        simtrace_video_s3_writers.append(
            SimtraceVideo(
                upload_type=SimtraceVideoNames.SIMTRACE_TRAINING.value,
                bucket=simtrace_s3_bucket,
                s3_prefix=simtrace_s3_object_prefix,
                region_name=aws_region,
                local_path=SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format("agent"),
            )
        )
    if mp4_s3_bucket:
        simtrace_video_s3_writers.extend(
            [
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.PIP.value,
                    bucket=mp4_s3_bucket,
                    s3_prefix=mp4_s3_object_prefix,
                    region_name=aws_region,
                    local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format("agent"),
                ),
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.DEGREE45.value,
                    bucket=mp4_s3_bucket,
                    s3_prefix=mp4_s3_object_prefix,
                    region_name=aws_region,
                    local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format("agent"),
                ),
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.TOPVIEW.value,
                    bucket=mp4_s3_bucket,
                    s3_prefix=mp4_s3_object_prefix,
                    region_name=aws_region,
                    local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format("agent"),
                ),
            ]
        )
    return simtrace_video_s3_writers


def main():
    """Entry point of a rollout worker.

    Steps:
    - Parse CLI / ROS configuration.
    - Fetch the reward function, model metadata, hyperparameters and an
      optional custom preset, using the configured S3 bucket and prefix.
    - Build the racecar, obstacles and bot-cars agents and the graph manager.
    - Attach the Redis pub/sub memory backend and the S3 checkpoint data
      store.
    - Run ``rollout_worker``.

    Workers whose ``rollout_idx`` is not needed cancel the simulation job.
    A worker is not needed when there are more workers than
    ``num_episodes_between_training``.
    """
    screen.set_use_colors(False)
    args = _parse_args()
    logger.info("S3 bucket: %s", args.s3_bucket)
    logger.info("S3 prefix: %s", args.s3_prefix)

    # Download and import reward function
    # TODO: replace 'agent' with name of each agent for multi-agent training
    reward_function_file = RewardFunction(
        bucket=args.s3_bucket,
        s3_key=args.reward_file_s3_key,
        region_name=args.aws_region,
        local_path=REWARD_FUCTION_LOCAL_PATH_FORMAT.format("agent"),
    )
    reward_function = reward_function_file.get_reward_function()

    # Instantiate Cameras
    configure_camera(namespaces=["racecar"])
    preset_file_success, _ = download_custom_files_if_present(
        s3_bucket=args.s3_bucket, s3_prefix=args.s3_prefix, aws_region=args.aws_region
    )

    # Download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format("agent"),
    )
    model_metadata_info = model_metadata.get_model_metadata_info()
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

    agent_config = {
        "model_metadata": model_metadata,
        ConfigParams.CAR_CTRL_CONFIG.value: {
            ConfigParams.LINK_NAME_LIST.value: LINK_NAMES,
            ConfigParams.VELOCITY_LIST.value: VELOCITY_TOPICS,
            ConfigParams.STEERING_LIST.value: STEERING_TOPICS,
            ConfigParams.CHANGE_START.value: utils.str2bool(
                rospy.get_param("CHANGE_START_POSITION", True)
            ),
            ConfigParams.ALT_DIR.value: utils.str2bool(
                rospy.get_param("ALTERNATE_DRIVING_DIRECTION", False)
            ),
            ConfigParams.MODEL_METADATA.value: model_metadata,
            ConfigParams.REWARD.value: reward_function,
            ConfigParams.AGENT_NAME.value: "racecar",
            ConfigParams.VERSION.value: version,
            ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
            ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
            ConfigParams.NUMBER_OF_TRIALS.value: None,
            ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
            ConfigParams.RACE_TYPE.value: args.race_type,
            ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
            ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty,
        },
    }

    # Resolve the region once. Falling back to --aws_region avoids a failure
    # when the AWS_REGION ROS param is unset. Metrics and video uploaders then
    # use the same value.
    aws_region = rospy.get_param("AWS_REGION", args.aws_region)

    #! TODO each agent should have own s3 bucket
    metrics_key = rospy.get_param("METRICS_S3_OBJECT_KEY")
    if args.num_workers > 1 and args.rollout_idx > 0:
        # Non-primary workers write to "<key>_<rollout_idx><ext>" to avoid clobbering.
        key_root, key_ext = os.path.splitext(metrics_key)
        metrics_key = "{}_{}{}".format(key_root, str(args.rollout_idx), key_ext)
    metrics_s3_config = {
        MetricsS3Keys.METRICS_BUCKET.value: rospy.get_param("METRICS_S3_BUCKET"),
        MetricsS3Keys.METRICS_KEY.value: metrics_key,
        MetricsS3Keys.REGION.value: aws_region,
    }

    run_phase_subject = RunPhaseSubject()
    agent_list = list()
    # TODO: replace agent for multi agent training
    # checkpoint s3 instance
    # TODO replace agent with agent_0 and so on for multiagent case
    checkpoint = Checkpoint(
        bucket=args.s3_bucket,
        s3_prefix=args.s3_prefix,
        region_name=args.aws_region,
        agent_name="agent",
        checkpoint_dir=args.checkpoint_dir,
    )
    agent_list.append(
        create_rollout_agent(
            agent_config,
            TrainingMetrics(
                agent_name="agent",
                s3_dict_metrics=metrics_s3_config,
                deepracer_checkpoint_json=checkpoint.deepracer_checkpoint_json,
                ckpnt_dir=os.path.join(args.checkpoint_dir, "agent"),
                run_phase_sink=run_phase_subject,
                # Only the first worker runs the model picker.
                use_model_picker=(args.rollout_idx == 0),
            ),
            run_phase_subject,
        )
    )
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver("/agent/training_phase", run_phase_subject)

    simtrace_video_s3_writers = _create_simtrace_video_s3_writers(
        args.rollout_idx, args.num_workers, aws_region
    )

    # The Redis address comes from IpConfig (presumably downloaded from the
    # shared S3 bucket/prefix), not from the --redis_ip argument.
    # TODO: replace 'agent' with specific agent name for multi agent training
    ip_config = IpConfig(
        bucket=args.s3_bucket,
        s3_prefix=args.s3_prefix,
        region_name=args.aws_region,
        local_path=IP_ADDRESS_LOCAL_PATH.format("agent"),
    )
    redis_ip = ip_config.get_ip_config()

    # Download hyperparameters from SageMaker shared s3 bucket
    # TODO: replace 'agent' with name of each agent
    hyperparameters = Hyperparameters(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, HYPERPARAMETER_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=HYPERPARAMETER_LOCAL_PATH_FORMAT.format("agent"),
    )
    sm_hyperparams_dict = hyperparameters.get_hyperparameters_dict()

    enable_domain_randomization = utils.str2bool(
        rospy.get_param("ENABLE_DOMAIN_RANDOMIZATION", False)
    )

    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service("/gazebo/pause_physics_dr")
    rospy.wait_for_service("/gazebo/unpause_physics_dr")
    pause_physics = ServiceProxyWrapper("/gazebo/pause_physics_dr", Empty)
    unpause_physics = ServiceProxyWrapper("/gazebo/unpause_physics_dr", Empty)

    if preset_file_success:
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        graph_manager, _ = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=run_phase_subject,
            enable_domain_randomization=enable_domain_randomization,
            pause_physics=pause_physics,
            unpause_physics=unpause_physics,
        )

    # If num_episodes_between_training is smaller than num_workers then cancel worker early.
    episode_steps_per_rollout = (
        graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    )
    # Reduce number of workers if allocated more than num_episodes_between_training
    if args.num_workers > episode_steps_per_rollout:
        logger.info(
            "Excess worker allocated. Reducing from %s to %s...",
            args.num_workers,
            episode_steps_per_rollout,
        )
        args.num_workers = episode_steps_per_rollout
    if args.rollout_idx >= episode_steps_per_rollout or args.rollout_idx >= args.num_workers:
        logger.info(
            "Exiting excess worker..."
            "(rollout_idx[%s] >= num_workers[%s] or num_episodes_between_training[%s])",
            args.rollout_idx,
            args.num_workers,
            episode_steps_per_rollout,
        )
        # Shut down the job.
        # NOTE(review): execution continues below unless cancel_simulation_job
        # terminates the process itself — confirm.
        utils.cancel_simulation_job()

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address=redis_ip,
        # Honor --redis_port (previously hard-coded to 6379, its default).
        redis_port=args.redis_port,
        run_type=str(RunType.ROLLOUT_WORKER),
        channel=args.s3_prefix,
        num_workers=args.num_workers,
        rollout_idx=args.rollout_idx,
    )
    graph_manager.memory_backend_params = memory_backend_params

    checkpoint_dict = {"agent": checkpoint}
    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)
    graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager)

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.checkpoint_dir

    rollout_worker(
        graph_manager=graph_manager,
        num_workers=args.num_workers,
        rollout_idx=args.rollout_idx,
        task_parameters=task_parameters,
        simtrace_video_s3_writers=simtrace_video_s3_writers,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics,
    )