def main()

in scripts/aws_launcher.py
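
Parses the launcher arguments, resolves the target EC2 instances through boto3, opens an SSH connection to each instance with paramiko, uploads the training script (plus any auxiliary files) into a temporary directory, and runs the script on every instance with the environment variables needed for distributed training. The temporary directory is removed once all commands have finished.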


def main():
    args = parse_args()

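    # The credentials file is read as an INI file; a [default] section is expected
    # to provide aws_access_key_id and aws_secret_access_key (used below).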
    cf = configparser.ConfigParser()
    cf.read(args.credentials)

    warnings.filterwarnings(
        "ignore", category=ResourceWarning, message="unclosed.*<ssl.SSLSocket.*>"
    )

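    # The regions, instance ids, and SSH key files are each given as comma-separated lists.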
    regions = args.regions.split(",")
    instance_ids = args.instances.split(",")
    ssh_key_files = args.ssh_key_file.split(",")

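    # Resolve the instance ids to EC2 instance objects. With multiple regions, each
    # instance id needs a matching region and SSH key; with a single region, the one
    # SSH key is reused for every instance.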
    instances = []
    if len(regions) > 1:
        print("Multiple regions detected")

        assert len(instance_ids) == len(
            ssh_key_files
        ), "{} instance ids are provided, but {} SSH keys found.".format(
            len(instance_ids), len(ssh_key_files)
        )

        assert len(instance_ids) == len(
            regions
        ), "{} instance ids are provided, but {} regions found.".format(
            len(instance_ids), len(regions)
        )

        for i, region in enumerate(regions):
            session = boto3.session.Session(
                aws_access_key_id=cf["default"]["aws_access_key_id"],
                aws_secret_access_key=cf["default"]["aws_secret_access_key"],
                region_name=region,
            )
            ec2 = session.resource("ec2")

            instance = get_instances(ec2, [instance_ids[i]])
            instances += instance
    else:
        session = boto3.session.Session(
            aws_access_key_id=cf["default"]["aws_access_key_id"],
            aws_secret_access_key=cf["default"]["aws_secret_access_key"],
            region_name=regions[0],
        )
        ec2 = session.resource("ec2")
        instances = get_instances(ec2, instance_ids)

        assert (
            len(ssh_key_files) == 1
        ), "1 region is detected, but {} SSH keys found.".format(len(ssh_key_files))

        ssh_key_files = [ssh_key_files[0] for _ in range(len(instances))]

    assert len(instance_ids) == len(
        instances
    ), "{} instance ids are provided, but {} found.".format(
        len(instance_ids), len(instances)
    )

    # Only print the public IP addresses of the instances.
    # Then do nothing else and return.
    if args.only_show_instance_ips:
        for instance in instances:
            print(instance.public_ip_address)
        return

    world_size = len(instances)
    print(f"Running world size {world_size} with instances: {instances}")
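    # The first instance acts as the master; its private IP address is exported
    # as MASTER_ADDR on every rank below.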
    master_instance = instances[0]

    # Key: instance id; value: paramiko.SSHClient object.
    client_dict = {}
    for i, instance in enumerate(instances):
        client = connect_to_instance(
            instance, ssh_key_files[i], args.ssh_user, args.http_proxy
        )
        client_dict[instance.id] = client

    assert os.path.exists(
        args.training_script
    ), f"File `{args.training_script}` does not exist"
    file_paths = args.aux_files.split(",") if args.aux_files else []
    for local_path in file_paths:
        assert os.path.exists(local_path), f"File `{local_path}` does not exist"

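    # Stage the training script and auxiliary files in a per-launch temporary
    # directory so that concurrent launches on the same instance cannot collide.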
    remote_dir = f"aws-launcher-tmp-{uuid.uuid1()}"
    script_basename = os.path.basename(args.training_script)
    remote_script = os.path.join(remote_dir, script_basename)

    # Upload files to all instances concurrently.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as uploaders:
        for instance_id, client in client_dict.items():
            run_command(instance_id, client, f"mkdir -p {remote_dir}")
            uploaders.submit(
                upload_file, instance_id, client, args.training_script, remote_script
            )
            for local_path in file_paths:
                uploaders.submit(
                    upload_file,
                    instance_id,
                    client,
                    local_path,
                    os.path.join(remote_dir, os.path.basename(local_path)),
                )
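    # Make the uploaded script executable; `ls -al` also logs what was staged.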
    for instance_id, client in client_dict.items():
        run_command(instance_id, client, f"chmod +x {remote_script}")
        run_command(instance_id, client, f"ls -al {remote_dir}")

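    # These variables follow the usual env:// rendezvous convention for distributed
    # training (WORLD_SIZE, MASTER_ADDR, MASTER_PORT, RANK); RANK is filled in per
    # instance in the launch loop below.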
    environment = {
        "WORLD_SIZE": str(world_size),
        "RENDEZVOUS": "env://",
        "MASTER_ADDR": master_instance.private_ip_address,
        "MASTER_PORT": str(args.master_port),
    }

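    # Launch the training script on every instance, assigning ranks in the order
    # the SSH clients were created.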
    with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as executor:
        rank = 0
        for instance_id, client in client_dict.items():
            environment["RANK"] = str(rank)
            # TODO: Although paramiko.SSHClient.exec_command() accepts an
            # `environment` argument, it does not seem to take effect in
            # practice, possibly because "servers may silently reject some
            # environment variables" (see the paramiko documentation). As a
            # workaround, all environment variables are exported explicitly
            # as part of the command itself.
            environment_cmd = "; ".join(
                [f"export {key}={value}" for (key, value) in environment.items()]
            )
            prepare_cmd = f"{args.prepare_cmd}; " if args.prepare_cmd else ""
            cmd = (
                f"{environment_cmd}; cd {remote_dir}; "
                f"{prepare_cmd}./{script_basename} "
                + " ".join(args.training_script_args)
            )
            print(f"Run command: {cmd}")
            # dict(environment) snapshots RANK for this task before the next iteration mutates it.
            executor.submit(run_command, instance_id, client, cmd, dict(environment))
            rank += 1

    # Cleanup temp dir.
    for instance_id, client in client_dict.items():
        run_command(instance_id, client, f"rm -rf {remote_dir}")
        client.close()
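
For illustration only, with hypothetical values (a two-instance job, master private IP 172.31.0.10, master port 29500, a training script named train.py taking --epochs 10, and no prepare command), the `Run command:` line printed for rank 0 would look roughly like:

    Run command: export WORLD_SIZE=2; export RENDEZVOUS=env://; export MASTER_ADDR=172.31.0.10; export MASTER_PORT=29500; export RANK=0; cd aws-launcher-tmp-<uuid>; ./train.py --epochs 10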