in scripts/aws_launcher.py [0:0]
def _make_ec2_resource(cf, region):
    """Return a boto3 EC2 resource for *region* built from the ``[default]``
    section of the parsed credentials file *cf*."""
    session = boto3.session.Session(
        aws_access_key_id=cf["default"]["aws_access_key_id"],
        aws_secret_access_key=cf["default"]["aws_secret_access_key"],
        region_name=region,
    )
    return session.resource("ec2")


def main():
    """Upload a training script to a set of EC2 instances and run it on all of them.

    Steps:
      1. Resolve the instances given by ``--instances`` in ``--regions`` and
         pair each one with an SSH key file (one key per instance when several
         regions are given, one shared key for a single region).
      2. With ``--only_show_instance_ips``, print the public IPs and return.
      3. Open an SSH connection to every instance and upload the training
         script plus any ``--aux_files`` into a fresh temporary directory.
      4. Run the script concurrently on every instance with WORLD_SIZE,
         RENDEZVOUS, MASTER_ADDR, MASTER_PORT and RANK exported. Rank ``i``
         is the ``i``-th instance; rank 0's private IP is the master address.
      5. Remove the temporary directory and close every SSH connection. This
         also happens when an earlier step fails.

    Raises:
        FileNotFoundError: if the credentials file could not be read.
        AssertionError: on inconsistent command-line arguments or missing
            local files.
        Exception: the first error raised by an upload or a training run is
            re-raised after cleanup has completed.
    """
    # The remote hosts run a POSIX shell, so remote paths must always use "/"
    # regardless of the OS this launcher runs on.
    import posixpath

    args = parse_args()
    cf = configparser.ConfigParser()
    # ConfigParser.read() silently skips unreadable files and returns the list
    # of files it did read; fail loudly here instead of with a KeyError on
    # cf["default"] later on.
    if not cf.read(args.credentials):
        raise FileNotFoundError(
            f"Could not read AWS credentials file `{args.credentials}`"
        )
    warnings.filterwarnings(
        "ignore", category=ResourceWarning, message="unclosed.*<ssl.SSLSocket.*>"
    )

    regions = args.regions.split(",")
    instance_ids = args.instances.split(",")
    ssh_key_files = args.ssh_key_file.split(",")
    instances = []
    if len(regions) > 1:
        # Instance ids, regions and SSH keys are matched up by position.
        print("Multiple regions detected")
        assert len(instance_ids) == len(
            ssh_key_files
        ), "{} instance ids are provided, but {} SSH keys found.".format(
            len(instance_ids), len(ssh_key_files)
        )
        assert len(instance_ids) == len(
            regions
        ), "{} instance ids are provided, but {} regions found.".format(
            len(instance_ids), len(regions)
        )
        for region, instance_id in zip(regions, instance_ids):
            ec2 = _make_ec2_resource(cf, region)
            instances += get_instances(ec2, [instance_id])
    else:
        # All instances live in the single region and share one SSH key.
        ec2 = _make_ec2_resource(cf, regions[0])
        instances = get_instances(ec2, instance_ids)
        assert (
            len(ssh_key_files) == 1
        ), "1 region is detected, but {} SSH keys found.".format(len(ssh_key_files))
        ssh_key_files = [ssh_key_files[0]] * len(instances)
    assert len(instance_ids) == len(
        instances
    ), "{} instance ids are provided, but {} found.".format(
        len(instance_ids), len(instances)
    )

    # Only print the public IP addresses of the instances.
    # Then do nothing else and return.
    if args.only_show_instance_ips:
        for instance in instances:
            print(instance.public_ip_address)
        return

    world_size = len(instances)
    print(f"Running world size {world_size} with instances: {instances}")
    master_instance = instances[0]

    # Validate local inputs before opening any SSH connection, so a typo in a
    # path does not leave connections dangling.
    assert os.path.exists(
        args.training_script
    ), f"File `{args.training_script}` does not exist"
    file_paths = args.aux_files.split(",") if args.aux_files else []
    for local_path in file_paths:
        assert os.path.exists(local_path), f"File `{local_path}` does not exist"

    remote_dir = f"aws-launcher-tmp-{uuid.uuid1()}"
    script_basename = os.path.basename(args.training_script)
    remote_script = posixpath.join(remote_dir, script_basename)
    # (local path, remote path) pairs uploaded to every instance.
    uploads = [(args.training_script, remote_script)] + [
        (local_path, posixpath.join(remote_dir, os.path.basename(local_path)))
        for local_path in file_paths
    ]

    # Key: instance id; value: paramiko.SSHClient object. Insertion order
    # follows `instances`, so the i-th entry gets rank i and entry 0 is the
    # master instance.
    client_dict = {}
    try:
        for instance, ssh_key_file in zip(instances, ssh_key_files):
            client_dict[instance.id] = connect_to_instance(
                instance, ssh_key_file, args.ssh_user, args.http_proxy
            )

        # Upload files to all instances concurrently.
        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as uploaders:
            upload_futures = []
            for instance_id, client in client_dict.items():
                run_command(instance_id, client, f"mkdir -p {remote_dir}")
                for local_path, remote_path in uploads:
                    upload_futures.append(
                        uploaders.submit(
                            upload_file, instance_id, client, local_path, remote_path
                        )
                    )
        # Surface upload failures instead of silently starting training
        # with missing files; Future.result() re-raises the worker's error.
        for future in upload_futures:
            future.result()

        for instance_id, client in client_dict.items():
            run_command(instance_id, client, f"chmod +x {remote_script}")
            run_command(instance_id, client, f"ls -al {remote_dir}")

        environment = {
            "WORLD_SIZE": str(world_size),
            "RENDEZVOUS": "env://",
            "MASTER_ADDR": master_instance.private_ip_address,
            "MASTER_PORT": str(args.master_port),
        }
        prepare_cmd = f"{args.prepare_cmd}; " if args.prepare_cmd else ""
        script_args = " ".join(args.training_script_args)
        with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as executor:
            train_futures = []
            for rank, (instance_id, client) in enumerate(client_dict.items()):
                # A fresh dict per rank: run_command reads it on a worker
                # thread, so sharing one mutable dict would let a later
                # iteration overwrite RANK before that thread uses it.
                rank_environment = dict(environment, RANK=str(rank))
                # TODO: Although paramiko.SSHClient.exec_command() can accept
                # an argument `environment`, it seems not to take effect in
                # practice. It might because "Servers may silently reject
                # some environment variables" according to paramiko document.
                # As a workaround, here all environment variables are explicitly
                # exported.
                environment_cmd = "; ".join(
                    f"export {key}={value}"
                    for key, value in rank_environment.items()
                )
                cmd = "{}; {} {} {} {}".format(
                    environment_cmd,
                    f"cd {remote_dir} ;",
                    prepare_cmd,
                    f"./{script_basename}",
                    script_args,
                )
                print(f"Run command: {cmd}")
                train_futures.append(
                    executor.submit(
                        run_command, instance_id, client, cmd, rank_environment
                    )
                )
        # All runs have finished once the executor exits; report the first
        # failure (after cleanup, via the finally clause below).
        for future in train_futures:
            future.result()
    finally:
        # Best-effort cleanup: remove the temp dir and always close the
        # connection, even if removing the dir on one instance fails.
        for instance_id, client in client_dict.items():
            try:
                run_command(instance_id, client, f"rm -rf {remote_dir}")
            except Exception as err:  # keep cleaning up the remaining clients
                print(f"Failed to remove {remote_dir} on {instance_id}: {err}")
            finally:
                client.close()