#!/usr/bin/env python3
# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release follow these steps:
# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this:
#    "v1.11.0": ("0.11.0", "rc1"),
# 2. Run script with following arguments for each of the supported python versions and
#    specify required RC tag for example: v1.11.0-rc3:
#    build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.7 --branch <RCtag>

import boto3
import os
import subprocess
import sys
import time
from typing import Dict, List, Optional, Tuple, Union


# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
    'ubuntu18_04': "ami-0f2b111fdc1647918",  # login_name: ubuntu
    'ubuntu20_04': "ami-0ea142bd244023692",  # login_name: ubuntu
    'redhat8': "ami-0698b90665a2ddcf1",  # login_name: ec2-user
}
ubuntu18_04_ami = os_amis['ubuntu18_04']


def compute_keyfile_path(key_name: Optional[str] = None) -> Tuple[str, str]:
    """Resolve (ssh_keyfile_path, key_name) from the argument or environment.

    Falls back to the AWS_KEY_NAME env var when key_name is None; the keyfile
    path defaults to ~/.ssh/<key_name>.pem unless SSH_KEY_PATH overrides it.
    Returns empty strings when no key name can be determined.
    """
    if key_name is None:
        key_name = os.getenv("AWS_KEY_NAME")
        if key_name is None:
            return os.getenv("SSH_KEY_PATH", ""), ""
    homedir_path = os.path.expanduser("~")
    default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
    return os.getenv("SSH_KEY_PATH", default_path), key_name


# NOTE: created at import time; requires valid AWS credentials/config.
ec2 = boto3.resource("ec2")


def ec2_get_instances(filter_name, filter_value):
    """Return an EC2 instance collection filtered by a single name/value pair."""
    return ec2.instances.filter(Filters=[{'Name': filter_name, 'Values': [filter_value]}])


def ec2_instances_of_type(instance_type='t4g.2xlarge'):
    """Return all instances of the given instance type."""
    return ec2_get_instances('instance-type', instance_type)


def ec2_instances_by_id(instance_id):
    """Return the instance with the given id, or None if not found."""
    rc = list(ec2_get_instances('instance-id', instance_id))
    return rc[0] if len(rc) > 0 else None


def start_instance(key_name, ami=ubuntu18_04_ami, instance_type='t4g.2xlarge'):
    """Launch a single EC2 instance, wait until it is running, and return it.

    Uses the 'ssh-allworld' security group and a 50 GB root EBS volume.
    """
    inst = ec2.create_instances(ImageId=ami,
                                InstanceType=instance_type,
                                SecurityGroups=['ssh-allworld'],
                                KeyName=key_name,
                                MinCount=1,
                                MaxCount=1,
                                BlockDeviceMappings=[
                                    {
                                        'DeviceName': '/dev/sda1',
                                        'Ebs': {
                                            'VolumeSize': 50,
                                            'VolumeType': 'standard'
                                        }
                                    }
                                ])[0]
    print(f'Create instance {inst.id}')
    inst.wait_until_running()
    # Re-fetch so public_dns_name is populated after the state transition.
    running_inst = ec2_instances_by_id(inst.id)
    print(f'Instance started at {running_inst.public_dns_name}')
    return running_inst


class RemoteHost:
    """Wrapper for running commands / copying files on a remote EC2 host over ssh,
    optionally proxied into a Docker container started on that host."""
    addr: str
    keyfile_path: str
    login_name: str
    container_id: Optional[str] = None
    ami: Optional[str] = None

    def __init__(self, addr: str, keyfile_path: str, login_name: str = 'ubuntu'):
        self.addr = addr
        self.keyfile_path = keyfile_path
        self.login_name = login_name

    def _gen_ssh_prefix(self) -> List[str]:
        """Common ssh argv prefix; '--' terminates ssh's own option parsing."""
        return ["ssh", "-o", "StrictHostKeyChecking=no", "-i", self.keyfile_path,
                f"{self.login_name}@{self.addr}", "--"]

    @staticmethod
    def _split_cmd(args: Union[str, List[str]]) -> List[str]:
        """Accept either a whitespace-separated string or an argv list."""
        return args.split() if isinstance(args, str) else args

    def run_ssh_cmd(self, args: Union[str, List[str]]) -> None:
        """Run a command on the host directly (never inside Docker)."""
        subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))

    def check_ssh_output(self, args: Union[str, List[str]]) -> str:
        """Run a command on the host directly and return its stdout as text."""
        return subprocess.check_output(self._gen_ssh_prefix() + self._split_cmd(args)).decode("utf-8")

    def scp_upload_file(self, local_file: str, remote_file: str) -> None:
        """Copy a local file to the host via scp."""
        subprocess.check_call(["scp", "-i", self.keyfile_path, local_file,
                               f"{self.login_name}@{self.addr}:{remote_file}"])

    def scp_download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
        """Copy a file from the host via scp; defaults to the current directory."""
        if local_file is None:
            local_file = "."
        subprocess.check_call(["scp", "-i", self.keyfile_path,
                               f"{self.login_name}@{self.addr}:{remote_file}", local_file])

    def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
        """Install Docker on the host, pull `image` and start a detached container.

        After this, run_cmd/check_output/upload_file/download_file operate
        inside the container (cwd /root).
        """
        self.run_ssh_cmd("sudo apt-get install -y docker.io")
        self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
        self.run_ssh_cmd("sudo service docker start")
        self.run_ssh_cmd(f"docker pull {image}")
        self.container_id = self.check_ssh_output(f"docker run -t -d -w /root {image}").strip()

    def using_docker(self) -> bool:
        return self.container_id is not None

    def run_cmd(self, args: Union[str, List[str]]) -> None:
        """Run a command on the host, or inside the container when one is active.

        Raises subprocess.CalledProcessError on a non-zero exit status.
        """
        if not self.using_docker():
            return self.run_ssh_cmd(args)
        assert self.container_id is not None
        # Pipe the command into `bash` inside the container; `source .bashrc`
        # first so conda/miniforge PATH changes take effect.
        docker_cmd = self._gen_ssh_prefix() + ['docker', 'exec', '-i', self.container_id, 'bash']
        p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
        p.communicate(input=" ".join(["source .bashrc;"] + self._split_cmd(args)).encode("utf-8"))
        rc = p.wait()
        if rc != 0:
            raise subprocess.CalledProcessError(rc, docker_cmd)

    def check_output(self, args: Union[str, List[str]]) -> str:
        """Like run_cmd, but capture and return stdout as text."""
        if not self.using_docker():
            return self.check_ssh_output(args)
        assert self.container_id is not None
        docker_cmd = self._gen_ssh_prefix() + ['docker', 'exec', '-i', self.container_id, 'bash']
        p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        (out, err) = p.communicate(input=" ".join(["source .bashrc;"] + self._split_cmd(args)).encode("utf-8"))
        rc = p.wait()
        if rc != 0:
            raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
        return out.decode("utf-8")

    def upload_file(self, local_file: str, remote_file: str) -> None:
        """Upload a file to the host, relaying through /tmp into the container
        (docker cp) when one is active."""
        if not self.using_docker():
            return self.scp_upload_file(local_file, remote_file)
        tmp_file = os.path.join("/tmp", os.path.basename(local_file))
        self.scp_upload_file(local_file, tmp_file)
        self.run_ssh_cmd(["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"])
        self.run_ssh_cmd(["rm", tmp_file])

    def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
        """Download a file from the host, relaying out of the container via /tmp
        when one is active."""
        if not self.using_docker():
            return self.scp_download_file(remote_file, local_file)
        tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
        self.run_ssh_cmd(["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file])
        self.scp_download_file(tmp_file, local_file)
        self.run_ssh_cmd(["rm", tmp_file])

    def download_wheel(self, remote_file: str, local_file: Optional[str] = None) -> None:
        """Download a wheel; when built inside the manylinux container, rename
        the platform tag from linux_aarch64 to manylinux2014_aarch64."""
        if self.using_docker() and local_file is None:
            basename = os.path.basename(remote_file)
            local_file = basename.replace("-linux_aarch64.whl", "-manylinux2014_aarch64.whl")
        self.download_file(remote_file, local_file)

    def list_dir(self, path: str) -> List[str]:
        """Return the entries of a remote directory (output of `ls -1`)."""
        return self.check_output(["ls", "-1", path]).split("\n")


def wait_for_connection(addr, port, timeout=5, attempt_cnt=5):
    """Poll until a TCP connection to addr:port succeeds.

    Re-raises the last error after attempt_cnt failed attempts.
    """
    import socket
    for i in range(attempt_cnt):
        try:
            with socket.create_connection((addr, port), timeout=timeout):
                return
        except (ConnectionRefusedError, socket.timeout):
            if i == attempt_cnt - 1:
                raise
            time.sleep(timeout)


def update_apt_repo(host: RemoteHost) -> None:
    """Run `apt-get update` on the host after stopping the automatic update
    services that would otherwise hold the dpkg lock."""
    time.sleep(5)
    host.run_cmd("sudo systemctl stop apt-daily.service || true")
    host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
    host.run_cmd("while systemctl is-active --quiet apt-daily.service; do sleep 1; done")
    host.run_cmd("while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done")
    host.run_cmd("sudo apt-get update")
    time.sleep(3)
    # Run twice: the first update occasionally races with the daily services.
    host.run_cmd("sudo apt-get update")


def install_condaforge(host: RemoteHost,
                       suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh") -> None:
    """Install Miniforge (conda-forge) on the host and put it on PATH."""
    print('Install conda-forge')
    host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
    host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
    host.run_cmd(f"rm -f {os.path.basename(suffix)}")
    if host.using_docker():
        host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
    else:
        host.run_cmd(['sed', '-i',
                      '\'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH\'',
                      '.bashrc'])


def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
    """Install conda-forge plus the requested Python and build prerequisites."""
    if python_version == "3.6":
        # Python-3.6 EOLed and not compatible with conda-4.11
        install_condaforge(host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh")
        host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
    else:
        install_condaforge(host)
        # Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
        host.run_cmd(f"conda install -y python={python_version} numpy pyyaml setuptools=59.5.0")


def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
    """Build and install a static OpenBLAS v0.3.19 tuned for ARMV8."""
    print('Building OpenBLAS')
    host.run_cmd(f"git clone https://github.com/xianyi/OpenBLAS -b v0.3.19 {git_clone_flags}")
    make_flags = "NUM_THREADS=64 USE_OPENMP=1 NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=ARMV8"
    host.run_cmd(f"pushd OpenBLAS; make {make_flags} -j8; sudo make {make_flags} install; popd; rm -rf OpenBLAS")


def build_FFTW(host: RemoteHost, git_clone_flags: str = "") -> None:
    """Build and install FFTW3 from source (currently unused by start_build)."""
    print("Building FFTW3")
    host.run_cmd("sudo apt-get install -y ocaml ocamlbuild autoconf automake indent libtool fig2dev texinfo")
    # TODO: fix a version to build
    # TODO: consider adding flags --host=arm-linux-gnueabi --enable-single --enable-neon CC=arm-linux-gnueabi-gcc -march=armv7-a -mfloat-abi=softfp
    host.run_cmd(f"git clone https://github.com/FFTW/fftw3 {git_clone_flags}")
    host.run_cmd("pushd fftw3; sh bootstrap.sh; make -j8; sudo make install; popd")


def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
    """Upload and run the embed_library helper to vendor libgomp.so.1 into a wheel."""
    host.run_cmd("pip3 install auditwheel")
    host.run_cmd("conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf")
    from tempfile import NamedTemporaryFile
    with NamedTemporaryFile() as tmp:
        tmp.write(embed_library_script.encode('utf-8'))
        tmp.flush()
        host.upload_file(tmp.name, "embed_library.py")

    print('Embedding libgomp into wheel')
    if host.using_docker():
        # Inside the manylinux container, also rewrite the wheel platform tag.
        host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
    else:
        host.run_cmd(f"python3 embed_library.py {wheel_name}")


def checkout_repo(host: RemoteHost, *,
                  branch: str = "master",
                  url: str,
                  git_clone_flags: str,
                  mapping: Dict[str, Tuple[str, str]]) -> Optional[str]:
    """Clone a domain-library repo at the tag mapped from the PyTorch branch.

    mapping: pytorch-branch prefix -> (library version, rc suffix).
    Returns the library version when a mapping matched, else clones the
    default branch and returns None.
    """
    for prefix in mapping:
        if not branch.startswith(prefix):
            continue
        tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
        host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
        return mapping[prefix][0]

    host.run_cmd(f"git clone {url} {git_clone_flags}")
    return None


def build_torchvision(host: RemoteHost, *,
                      branch: str = "master",
                      use_conda: bool = True,
                      git_clone_flags: str) -> str:
    """Build the TorchVision wheel on the host and download it. Returns the wheel name."""
    print('Checking out TorchVision repo')
    build_version = checkout_repo(host,
                                  branch=branch,
                                  url="https://github.com/pytorch/vision",
                                  git_clone_flags=git_clone_flags,
                                  mapping={
                                      "v1.7.1": ("0.8.2", "rc2"),
                                      "v1.8.0": ("0.9.0", "rc3"),
                                      "v1.8.1": ("0.9.1", "rc1"),
                                      "v1.9.0": ("0.10.0", "rc1"),
                                      "v1.10.0": ("0.11.1", "rc1"),
                                      "v1.10.1": ("0.11.2", "rc1"),
                                      "v1.10.2": ("0.11.3", "rc1"),
                                      "v1.11.0": ("0.12.0", "rc1"),
                                  })
    print('Building TorchVision wheel')
    build_vars = ""
    if branch == 'nightly':
        version = host.check_output(["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]).strip()
        if len(version) == 0:
            # In older revisions, version was embedded in setup.py
            version = host.check_output(["grep", "\"version = '\"", "vision/setup.py"]).strip().split("'")[1][:-2]
        # Nightly dev date is taken from the pytorch checkout's last commit subject.
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd vision; {build_vars} python3 setup.py bdist_wheel")
    vision_wheel_name = host.list_dir("vision/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('vision', 'dist', vision_wheel_name))

    print('Copying TorchVision wheel')
    host.download_wheel(os.path.join('vision', 'dist', vision_wheel_name))
    print("Delete vision checkout")
    host.run_cmd("rm -rf vision")

    return vision_wheel_name


def build_torchtext(host: RemoteHost, *,
                    branch: str = "master",
                    use_conda: bool = True,
                    git_clone_flags: str = "") -> str:
    """Build the TorchText wheel on the host and download it. Returns the wheel name."""
    print('Checking out TorchText repo')
    git_clone_flags += " --recurse-submodules"
    build_version = checkout_repo(host,
                                  branch=branch,
                                  url="https://github.com/pytorch/text",
                                  git_clone_flags=git_clone_flags,
                                  mapping={
                                      "v1.9.0": ("0.10.0", "rc1"),
                                      "v1.10.0": ("0.11.0", "rc2"),
                                      "v1.10.1": ("0.11.1", "rc1"),
                                      "v1.10.2": ("0.11.2", "rc1"),
                                      "v1.11.0": ("0.12.0", "rc1"),
                                  })
    print('Building TorchText wheel')
    build_vars = ""
    if branch == 'nightly':
        version = host.check_output(["if [ -f text/version.txt ]; then cat text/version.txt; fi"]).strip()
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd text; {build_vars} python3 setup.py bdist_wheel")
    wheel_name = host.list_dir("text/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('text', 'dist', wheel_name))

    print('Copying TorchText wheel')
    host.download_wheel(os.path.join('text', 'dist', wheel_name))

    return wheel_name


def build_torchaudio(host: RemoteHost, *,
                     branch: str = "master",
                     use_conda: bool = True,
                     git_clone_flags: str = "") -> str:
    """Build the TorchAudio wheel on the host and download it. Returns the wheel name."""
    print('Checking out TorchAudio repo')
    git_clone_flags += " --recurse-submodules"
    build_version = checkout_repo(host,
                                  branch=branch,
                                  url="https://github.com/pytorch/audio",
                                  git_clone_flags=git_clone_flags,
                                  mapping={
                                      "v1.9.0": ("0.9.0", "rc2"),
                                      "v1.10.0": ("0.10.0", "rc5"),
                                      "v1.10.1": ("0.10.1", "rc1"),
                                      "v1.10.2": ("0.10.2", "rc1"),
                                      "v1.11.0": ("0.11.0", "rc1"),
                                  })
    print('Building TorchAudio wheel')
    build_vars = ""
    if branch == 'nightly':
        version = host.check_output(["grep", "\"version = '\"", "audio/setup.py"]).strip().split("'")[1][:-2]
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd audio; {build_vars} python3 setup.py bdist_wheel")
    wheel_name = host.list_dir("audio/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('audio', 'dist', wheel_name))

    print('Copying TorchAudio wheel')
    host.download_wheel(os.path.join('audio', 'dist', wheel_name))

    return wheel_name


def configure_system(host: RemoteHost, *,
                     compiler="gcc-8",
                     use_conda=True,
                     python_version="3.8") -> None:
    """Install toolchain, Python and build prerequisites on the host."""
    if use_conda:
        install_condaforge_python(host, python_version)

    print('Configuring the system')
    if not host.using_docker():
        update_apt_repo(host)
        host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
    else:
        # manylinux images are CentOS-based and lack sudo by default.
        host.run_cmd("yum install -y sudo")
        host.run_cmd("conda install -y ninja")
    if not use_conda:
        host.run_cmd("sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip")
    host.run_cmd("pip3 install dataclasses typing-extensions")
    # Install and switch to gcc-8 on Ubuntu-18.04
    if not host.using_docker() and host.ami == ubuntu18_04_ami and compiler == 'gcc-8':
        host.run_cmd("sudo apt-get install -y g++-8 gfortran-8")
        host.run_cmd("sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100")
        host.run_cmd("sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 100")
        host.run_cmd("sudo update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-8 100")
    if not use_conda:
        print("Installing Cython + numpy from PyPy")
        host.run_cmd("sudo pip3 install Cython")
        host.run_cmd("sudo pip3 install numpy")


def start_build(host: RemoteHost, *,
                branch="master",
                compiler="gcc-8",
                use_conda=True,
                python_version="3.8",
                shallow_clone=True) -> Tuple[str, str]:
    """Full pipeline: configure host, build PyTorch + domain-library wheels.

    Returns (pytorch_wheel_name, vision_wheel_name).
    """
    git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
    if host.using_docker() and not use_conda:
        print("Auto-selecting conda option for docker images")
        use_conda = True

    configure_system(host, compiler=compiler, use_conda=use_conda, python_version=python_version)
    build_OpenBLAS(host, git_clone_flags)
    # build_FFTW(host, git_clone_flags)

    if host.using_docker():
        print("Move libgfortant.a into a standard location")
        # HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
        # libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17'
        # Workaround by copying gfortran library from the host
        host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
        host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
        host.run_ssh_cmd(["docker", "cp",
                          "/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
                          f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/"
                          ])

    print('Checking out PyTorch repo')
    host.run_cmd(f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}")

    print('Building PyTorch wheel')
    # Breakpad build fails on aarch64
    build_vars = "USE_BREAKPAD=0 "
    if branch == 'nightly':
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        version = host.check_output("cat pytorch/version.txt").strip()[:-2]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
    if branch.startswith("v1."):
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1:branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd pytorch ; {build_vars} python3 setup.py bdist_wheel")
    print("Deleting build folder")
    host.run_cmd("cd pytorch; rm -rf build")
    pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('pytorch', 'dist', pytorch_wheel_name))
    print('Copying the wheel')
    host.download_wheel(os.path.join('pytorch', 'dist', pytorch_wheel_name))

    print('Installing PyTorch wheel')
    host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")

    vision_wheel_name = build_torchvision(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
    build_torchaudio(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
    build_torchtext(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)

    return pytorch_wheel_name, vision_wheel_name


# Helper script uploaded to the build host by embed_libgomp(); it vendors a
# shared library (libgomp.so.1) into a wheel and optionally rewrites the
# platform tag to manylinux2014.
embed_library_script = """
#!/usr/bin/env python3

from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory


def replace_tag(filename):
    with open(filename, 'r') as f:
        lines = f.read().split("\\n")
    for i, line in enumerate(lines):
        if not line.startswith("Tag: "):
            continue
        lines[i] = line.replace("-linux_", "-manylinux2014_")
        print(f'Updated tag from {line} to {lines[i]}')

    with open(filename, 'w') as f:
        f.write("\\n".join(lines))


class AlignedPatchelf(Patchelf):
    def set_soname(self, file_name: str, new_soname: str) -> None:
        check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])

    def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
        check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])


def embed_library(whl_path, lib_soname, update_tag=False):
    patcher = AlignedPatchelf()
    out_dir = TemporaryDirectory()
    whl_name = os.path.basename(whl_path)
    tmp_whl_name = os.path.join(out_dir.name, whl_name)
    with InWheelCtx(whl_path) as ctx:
        torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
        ctx.out_wheel = tmp_whl_name
        new_lib_path, new_lib_soname = None, None
        for filename, elf in elf_file_filter(ctx.iter_files()):
            if not filename.startswith('torch/lib'):
                continue
            libtree = lddtree(filename)
            if lib_soname not in libtree['needed']:
                continue
            lib_path = libtree['libs'][lib_soname]['path']
            if lib_path is None:
                print(f"Can't embed {lib_soname} as it could not be found")
                break
            if lib_path.startswith(torchlib_path):
                continue

            if new_lib_path is None:
                new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
            patcher.replace_needed(filename, lib_soname, new_lib_soname)
            print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
        if update_tag:
            # Add manylinux2014 tag
            for filename in ctx.iter_files():
                if os.path.basename(filename) != 'WHEEL':
                    continue
                replace_tag(filename)
    shutil.move(tmp_whl_name, whl_path)


if __name__ == '__main__':
    embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""


def run_tests(host: RemoteHost, whl: str, branch='master') -> None:
    """Install a previously built wheel on the host and run PyTorch's test_torch."""
    print('Configuring the system')
    update_apt_repo(host)
    host.run_cmd("sudo apt-get install -y python3-pip git")
    host.run_cmd("sudo pip3 install Cython")
    host.run_cmd("sudo pip3 install numpy")
    host.upload_file(whl, ".")
    host.run_cmd(f"sudo pip3 install {whl}")
    # Smoke test (fixed: original command was missing a closing parenthesis).
    host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3)))'")
    host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
    host.run_cmd("cd pytorch/test; python3 test_torch.py -v")


def get_instance_name(instance) -> Optional[str]:
    """Return the value of the instance's 'Name' tag, or None."""
    if instance.tags is None:
        return None
    for tag in instance.tags:
        if tag['Key'] == 'Name':
            return tag['Value']
    return None


def list_instances(instance_type: str) -> None:
    """Print id, name, DNS name and state of all instances of the given type."""
    print(f"All instances of type {instance_type}")
    for instance in ec2_instances_of_type(instance_type):
        print(f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']}")


def terminate_instances(instance_type: str) -> None:
    """Terminate all instances of the given type and wait for completion."""
    print(f"Terminating all instances of type {instance_type}")
    instances = list(ec2_instances_of_type(instance_type))
    for instance in instances:
        print(f"Terminating {instance.id}")
        instance.terminate()
    print("Waiting for termination to complete")
    for instance in instances:
        instance.wait_until_terminated()


def parse_arguments():
    """Parse the command line arguments; see the module header for typical usage."""
    from argparse import ArgumentParser
    parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
    parser.add_argument("--key-name", type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--build-only", action="store_true")
    parser.add_argument("--test-only", type=str)
    parser.add_argument("--os", type=str, choices=list(os_amis.keys()), default='ubuntu18_04')
    parser.add_argument("--python-version", type=str, choices=['3.6', '3.7', '3.8', '3.9', '3.10'], default=None)
    parser.add_argument("--alloc-instance", action="store_true")
    parser.add_argument("--list-instances", action="store_true")
    parser.add_argument("--keep-running", action="store_true")
    parser.add_argument("--terminate-instances", action="store_true")
    parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
    parser.add_argument("--branch", type=str, default="master")
    parser.add_argument("--use-docker", action="store_true")
    parser.add_argument("--compiler", type=str, choices=['gcc-7', 'gcc-8', 'gcc-9', 'clang'], default="gcc-8")
    parser.add_argument("--use-torch-from-pypi", action="store_true")
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    ami = os_amis[args.os]
    keyfile_path, key_name = compute_keyfile_path(args.key_name)
    if args.list_instances:
        list_instances(args.instance_type)
        sys.exit(0)

    if args.terminate_instances:
        terminate_instances(args.instance_type)
        sys.exit(0)

    if len(key_name) == 0:
        raise Exception("""
            Cannot start build without key_name, please specify
            --key-name argument or AWS_KEY_NAME environment variable.""")
    if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
        raise Exception(f"""
            Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
            check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")

    # Starting the instance
    inst = start_instance(key_name, ami=ami)
    instance_name = f'{args.key_name}-{args.os}'
    if args.python_version is not None:
        instance_name += f'-py{args.python_version}'
    inst.create_tags(DryRun=False, Tags=[{
        'Key': 'Name',
        'Value': instance_name,
    }])
    addr = inst.public_dns_name
    wait_for_connection(addr, 22)
    host = RemoteHost(addr, keyfile_path)
    host.ami = ami
    if args.use_docker:
        update_apt_repo(host)
        host.start_docker()

    if args.test_only:
        run_tests(host, args.test_only)
        sys.exit(0)

    if args.alloc_instance:
        if args.python_version is None:
            sys.exit(0)
        install_condaforge_python(host, args.python_version)
        sys.exit(0)

    python_version = args.python_version if args.python_version is not None else '3.8'
    if args.use_torch_from_pypi:
        configure_system(host, compiler=args.compiler, python_version=python_version)
        print("Installing PyTorch wheel")
        host.run_cmd("pip3 install torch")
        build_torchvision(host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules")
    else:
        start_build(host, branch=args.branch, compiler=args.compiler, python_version=python_version)
    if not args.keep_running:
        print(f'Waiting for instance {inst.id} to terminate')
        inst.terminate()
        inst.wait_until_terminated()