in src/markov/rollout_worker.py [0:0]
def main():
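# Disable colored console output (screen is presumably RL Coach's screen logger) so the
# simulation logs stay free of ANSI escape codes.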
screen.set_use_colors(False)
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--checkpoint_dir',
help='(string) Path to a folder containing a checkpoint to restore the model from.',
type=str,
default='./checkpoint_robomaker')
parser.add_argument('--s3_bucket',
help='(string) S3 bucket',
type=str,
default=rospy.get_param("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"))
parser.add_argument('--s3_prefix',
help='(string) S3 prefix',
type=str,
default=rospy.get_param("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"))
parser.add_argument('--num_workers',
help="(int) The number of workers started in this pool",
type=int,
default=int(rospy.get_param("NUM_WORKERS", 1)))
parser.add_argument('--rollout_idx',
help="(int) The index of current rollout worker",
type=int,
default=0)
parser.add_argument('-r', '--redis_ip',
help="(string) IP or host for the redis server",
default='localhost',
type=str)
parser.add_argument('-rp', '--redis_port',
help="(int) Port of the redis server",
default=6379,
type=int)
parser.add_argument('--aws_region',
help='(string) AWS region',
type=str,
default=rospy.get_param("AWS_REGION", "us-east-1"))
parser.add_argument('--reward_file_s3_key',
help='(string) Reward File S3 Key',
type=str,
default=rospy.get_param("REWARD_FILE_S3_KEY", None))
parser.add_argument('--model_metadata_s3_key',
help='(string) Model Metadata File S3 Key',
type=str,
default=rospy.get_param("MODEL_METADATA_FILE_S3_KEY", None))
# For a training job, resets are not allowed; penalty_seconds, off_track_penalty, and
# collision_penalty all default to 0
parser.add_argument('--number_of_resets',
help='(integer) Number of resets',
type=int,
default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
parser.add_argument('--penalty_seconds',
help='(float) penalty in seconds',
type=float,
default=float(rospy.get_param("PENALTY_SECONDS", 0.0)))
parser.add_argument('--job_type',
help='(string) job type',
type=str,
default=rospy.get_param("JOB_TYPE", "TRAINING"))
parser.add_argument('--is_continuous',
help='(boolean) whether the race continues after lap completion',
type=utils.str2bool,
default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)))
parser.add_argument('--race_type',
help='(string) Race type',
type=str,
default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
parser.add_argument('--off_track_penalty',
help='(float) off-track penalty in seconds',
type=float,
default=float(rospy.get_param("OFF_TRACK_PENALTY", 0.0)))
parser.add_argument('--collision_penalty',
help='(float) collision penalty in seconds',
type=float,
default=float(rospy.get_param("COLLISION_PENALTY", 0.0)))
args = parser.parse_args()
logger.info("S3 bucket: %s", args.s3_bucket)
logger.info("S3 prefix: %s", args.s3_prefix)
# Download and import reward function
# TODO: replace 'agent' with name of each agent for multi-agent training
reward_function_file = RewardFunction(bucket=args.s3_bucket,
s3_key=args.reward_file_s3_key,
region_name=args.aws_region,
local_path=REWARD_FUCTION_LOCAL_PATH_FORMAT.format('agent'))
reward_function = reward_function_file.get_reward_function()
# Instantiate Cameras
configure_camera(namespaces=['racecar'])
preset_file_success, _ = download_custom_files_if_present(s3_bucket=args.s3_bucket,
s3_prefix=args.s3_prefix,
aws_region=args.aws_region)
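# preset_file_success indicates whether a custom preset file was found in S3; if so, it is
# imported further below instead of building the default graph manager.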
# download model metadata
# TODO: replace 'agent' with name of each agent
model_metadata = ModelMetadata(bucket=args.s3_bucket,
s3_key=args.model_metadata_s3_key,
region_name=args.aws_region,
local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
model_metadata_info = model_metadata.get_model_metadata_info()
version = model_metadata_info[ModelMetadataKeys.VERSION.value]
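# Assemble the rollout agent configuration: vehicle control topics, start-position and
# driving-direction behaviour, the reward function, and the reset/penalty settings for
# this race type.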
agent_config = {
'model_metadata': model_metadata,
ConfigParams.CAR_CTRL_CONFIG.value: {
ConfigParams.LINK_NAME_LIST.value: LINK_NAMES,
ConfigParams.VELOCITY_LIST.value : VELOCITY_TOPICS,
ConfigParams.STEERING_LIST.value : STEERING_TOPICS,
ConfigParams.CHANGE_START.value : utils.str2bool(rospy.get_param('CHANGE_START_POSITION', True)),
ConfigParams.ALT_DIR.value : utils.str2bool(rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
ConfigParams.MODEL_METADATA.value : model_metadata,
ConfigParams.REWARD.value : reward_function,
ConfigParams.AGENT_NAME.value : 'racecar',
ConfigParams.VERSION.value : version,
ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
ConfigParams.NUMBER_OF_TRIALS.value: None,
ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
ConfigParams.RACE_TYPE.value: args.race_type,
ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty
}
}
# TODO: each agent should have its own S3 bucket
metrics_key = rospy.get_param('METRICS_S3_OBJECT_KEY')
if args.num_workers > 1 and args.rollout_idx > 0:
key_tuple = os.path.splitext(metrics_key)
metrics_key = "{}_{}{}".format(key_tuple[0],
str(args.rollout_idx),
key_tuple[1])
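# e.g. for rollout_idx 2, a key like 'metrics/training-metrics.json' becomes
# 'metrics/training-metrics_2.json', so each worker writes to its own metrics object.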
metrics_s3_config = {MetricsS3Keys.METRICS_BUCKET.value: rospy.get_param('METRICS_S3_BUCKET'),
MetricsS3Keys.METRICS_KEY.value: metrics_key,
MetricsS3Keys.REGION.value: rospy.get_param('AWS_REGION')}
run_phase_subject = RunPhaseSubject()
agent_list = list()
# Checkpoint S3 instance
# TODO: replace 'agent' with 'agent_0', 'agent_1', and so on for the multi-agent case
checkpoint = Checkpoint(bucket=args.s3_bucket,
s3_prefix=args.s3_prefix,
region_name=args.aws_region,
agent_name='agent',
checkpoint_dir=args.checkpoint_dir)
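# Checkpoint wraps the S3 location of this agent's model checkpoints; it also exposes the
# deepracer_checkpoint_json used by the training metrics below and is handed to the S3 data
# store at the end of this function.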
agent_list.append(create_rollout_agent(agent_config,
TrainingMetrics(agent_name='agent',
s3_dict_metrics=metrics_s3_config,
deepracer_checkpoint_json=checkpoint.deepracer_checkpoint_json,
ckpnt_dir=os.path.join(args.checkpoint_dir, 'agent'),
run_phase_sink=run_phase_subject,
use_model_picker=(args.rollout_idx == 0)),
run_phase_subject))
agent_list.append(create_obstacles_agent())
agent_list.append(create_bot_cars_agent())
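# create_obstacles_agent() and create_bot_cars_agent() add non-learning agents controlled by
# the environment; presumably they only take effect for race types that include obstacles or
# bot cars.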
# ROS service to indicate all the robomaker markov packages are ready for consumption
signal_robomaker_markov_package_ready()
PhaseObserver('/agent/training_phase', run_phase_subject)
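# PhaseObserver appears to bridge the /agent/training_phase topic into run_phase_subject so
# that metrics and agents can react to training/evaluation phase changes.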
aws_region = rospy.get_param('AWS_REGION', args.aws_region)
simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None) if args.rollout_idx == 0 else None
if simtrace_s3_bucket:
simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
if args.num_workers > 1:
simtrace_s3_object_prefix = os.path.join(simtrace_s3_object_prefix, str(args.rollout_idx))
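# With multiple workers, each one uploads sim-trace logs under its own rollout_idx sub-prefix
# so the workers do not overwrite each other's files.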
if mp4_s3_bucket:
mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
simtrace_video_s3_writers = []
#TODO: replace 'agent' with 'agent_0' for multi agent training and
# mp4_s3_object_prefix, mp4_s3_bucket will be a list, so need to access with index
if simtrace_s3_bucket:
simtrace_video_s3_writers.append(
SimtraceVideo(upload_type=SimtraceVideoNames.SIMTRACE_TRAINING.value,
bucket=simtrace_s3_bucket,
s3_prefix=simtrace_s3_object_prefix,
region_name=aws_region,
local_path=SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format('agent')))
if mp4_s3_bucket:
simtrace_video_s3_writers.extend([
SimtraceVideo(upload_type=SimtraceVideoNames.PIP.value,
bucket=mp4_s3_bucket,
s3_prefix=mp4_s3_object_prefix,
region_name=aws_region,
local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format('agent')),
SimtraceVideo(upload_type=SimtraceVideoNames.DEGREE45.value,
bucket=mp4_s3_bucket,
s3_prefix=mp4_s3_object_prefix,
region_name=aws_region,
local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format('agent')),
SimtraceVideo(upload_type=SimtraceVideoNames.TOPVIEW.value,
bucket=mp4_s3_bucket,
s3_prefix=mp4_s3_object_prefix,
region_name=aws_region,
local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format('agent'))])
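# Only the primary worker (rollout_idx 0) has mp4_s3_bucket set (see above), so the camera
# MP4 uploads (picture-in-picture, 45-degree, and top-view) come from a single worker, while
# every worker may upload sim-trace logs.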
# TODO: replace 'agent' with specific agent name for multi agent training
ip_config = IpConfig(bucket=args.s3_bucket,
s3_prefix=args.s3_prefix,
region_name=args.aws_region,
local_path=IP_ADDRESS_LOCAL_PATH.format('agent'))
redis_ip = ip_config.get_ip_config()
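# The Redis address is taken from the IP config stored in the shared S3 bucket (presumably
# written by the training job) rather than from --redis_ip, so the worker connects to the
# trainer's Redis instance.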
# Download hyperparameters from SageMaker shared s3 bucket
# TODO: replace 'agent' with name of each agent
hyperparameters = Hyperparameters(bucket=args.s3_bucket,
s3_key=get_s3_key(args.s3_prefix, HYPERPARAMETER_S3_POSTFIX),
region_name=args.aws_region,
local_path=HYPERPARAMETER_LOCAL_PATH_FORMAT.format('agent'))
sm_hyperparams_dict = hyperparameters.get_hyperparameters_dict()
enable_domain_randomization = utils.str2bool(rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))
# Make the clients that will allow us to pause and unpause the physics
rospy.wait_for_service('/gazebo/pause_physics_dr')
rospy.wait_for_service('/gazebo/unpause_physics_dr')
pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)
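# Service proxies for pausing/unpausing Gazebo physics; the rollout loop can use them to
# freeze the simulation, e.g. while a new checkpoint is being loaded.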
if preset_file_success:
preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
preset_location += ":graph_manager"
graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
logger.info("Using custom preset file!")
else:
graph_manager, _ = get_graph_manager(hp_dict=sm_hyperparams_dict, agent_list=agent_list,
run_phase_subject=run_phase_subject,
enable_domain_randomization=enable_domain_randomization,
pause_physics=pause_physics,
unpause_physics=unpause_physics)
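# The graph manager ties together the agents, environment, and algorithm parameters, built
# either from the custom preset downloaded above or from the SageMaker hyperparameters.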
# If num_episodes_between_training is smaller than num_workers, cancel the excess workers early.
episode_steps_per_rollout = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
# Reduce the number of workers if more were allocated than num_episodes_between_training
if args.num_workers > episode_steps_per_rollout:
logger.info("Excess worker allocated. Reducing from {} to {}...".format(args.num_workers,
episode_steps_per_rollout))
args.num_workers = episode_steps_per_rollout
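# e.g. if num_episodes_between_training is 20 and 24 workers were requested, only workers
# 0-19 are kept and the rest exit below.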
if args.rollout_idx >= episode_steps_per_rollout or args.rollout_idx >= args.num_workers:
err_msg_format = "Exiting excess worker..."
err_msg_format += "(rollout_idx[{}] >= num_workers[{}] or num_episodes_between_training[{}])"
logger.info(err_msg_format.format(args.rollout_idx, args.num_workers, episode_steps_per_rollout))
# Close down the job
utils.cancel_simulation_job()
memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(redis_address=redis_ip,
redis_port=args.redis_port,
run_type=str(RunType.ROLLOUT_WORKER),
channel=args.s3_prefix,
num_workers=args.num_workers,
rollout_idx=args.rollout_idx)
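# Experiences collected by this worker are published to the trainer over Redis pub/sub; the
# channel is keyed by the shared S3 prefix and each worker is identified by its rollout index.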
graph_manager.memory_backend_params = memory_backend_params
checkpoint_dict = {'agent': checkpoint}
ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)
graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager)
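# The S3-backed data store is presumably where this worker picks up new model checkpoints
# published by the training job.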
task_parameters = TaskParameters()
task_parameters.checkpoint_restore_path = args.checkpoint_dir
rollout_worker(
graph_manager=graph_manager,
num_workers=args.num_workers,
rollout_idx=args.rollout_idx,
task_parameters=task_parameters,
simtrace_video_s3_writers=simtrace_video_s3_writers,
pause_physics=pause_physics,
unpause_physics=unpause_physics
)
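# rollout_worker() runs the main experience-collection loop: restore the latest checkpoint,
# roll out episodes, publish experiences through the memory backend, and upload the
# sim-trace/video artifacts, presumably until training completes.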