optimum/tpu/cli.py (76 lines of code) (raw):
import importlib.util
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Sequence, Union

import click
import typer
# Pinned versions used by the install commands.
TORCH_VER = "2.5.1"  # shared by the torch, torch-xla[tpu] and torch_xla[pallas] installs
JETSTREAM_PT_VER = "jetstream-v0.2.4"  # git ref checked out in the jetstream-pytorch clone
# Default parent directory for the jetstream-pytorch checkout: <home>/.jetstream-deps
DEFAULT_DEPS_PATH = str(Path.home() / ".jetstream-deps")
# Typer application object; the functions below decorated with `@app.command()`
# are registered on it as CLI commands.
app = typer.Typer()
def _check_module(module_name: str):
spec = importlib.util.find_spec(module_name)
return spec is not None
def _run(cmd: str):
split_cmd = cmd.split()
subprocess.check_call(split_cmd)
def _install_torch_cpu():
    """Install the pinned torch release from PyTorch's CPU-only wheel index.

    Using the CPU index avoids installing CUDA dependencies.
    """
    cpu_wheels = "https://download.pytorch.org/whl/cpu"
    _run(f"{sys.executable} -m pip install torch=={TORCH_VER} --index-url {cpu_wheels}")
@app.command()
def install_pytorch_xla(
    force: bool = False,
):
    """
    Installs PyTorch XLA with TPU support.
    Args:
        force (bool): When set, force reinstalling even if Pytorch XLA is already installed.
    """
    # Unless `force` is set, ask before reinstalling over an existing
    # torch + torch_xla setup; declining aborts the command (abort=True).
    if not force:
        installed = all(_check_module(pkg) for pkg in ("torch", "torch_xla"))
        if installed:
            typer.confirm(
                "PyTorch XLA is already installed. Do you want to reinstall it?",
                default=False,
                abort=True,
            )

    # torch comes from the CPU-only index first; torch-xla[tpu] is then
    # resolved with the libtpu releases page as an extra find-links source.
    _install_torch_cpu()
    libtpu_links = "https://storage.googleapis.com/libtpu-releases/index.html"
    _run(f"{sys.executable} -m pip install torch-xla[tpu]=={TORCH_VER} -f {libtpu_links}")

    finished = click.style("PyTorch XLA has been installed.", bold=True)
    click.echo()
    click.echo(finished)
@app.command()
def install_jetstream_pytorch(
    deps_path: str = DEFAULT_DEPS_PATH,
    yes: bool = False,
):
    """
    Installs Jetstream Pytorch with TPU support.
    Args:
        deps_path (str): Path where Jetstream Pytorch dependencies will be installed.
        yes (bool): When set, proceed installing without asking questions.
    """
    # Ask every question up front: typer.confirm(abort=True) aborts on "no",
    # and at that point nothing must have been installed or deleted yet.
    if not yes and _check_module("jetstream_pt") and _check_module("torch_xla2"):
        typer.confirm(
            "Jetstream Pytorch is already installed. Do you want to reinstall it?",
            default=False,
            abort=True,
        )
    jetstream_repo_dir = os.path.join(deps_path, "jetstream-pytorch")
    if not yes and os.path.exists(jetstream_repo_dir):
        typer.confirm(
            f"Directory {jetstream_repo_dir} already exists. Do you want to delete it and reinstall Jetstream Pytorch?",
            default=False,
            abort=True,
        )

    # Install the CPU build of torch (only when torch is missing) before the
    # editable install below.
    if not _check_module("torch"):
        _install_torch_cpu()

    # The steps below chdir into the checkout; remember the caller's working
    # directory first (before the rmtree, which could remove it) so it can be
    # restored afterwards.
    previous_cwd = os.getcwd()

    # Start from a fresh checkout.
    shutil.rmtree(jetstream_repo_dir, ignore_errors=True)
    # Create the directory if it does not exist
    os.makedirs(deps_path, exist_ok=True)

    try:
        # Clone Jetstream Pytorch and check out the pinned ref.
        os.chdir(deps_path)
        _run("git clone https://github.com/google/jetstream-pytorch.git")
        os.chdir("jetstream-pytorch")
        _run(f"git checkout {JETSTREAM_PT_VER}")
        _run("git submodule update --init --recursive")
        # We cannot install in a temporary directory: the package is installed in
        # editable mode (`-e .`), so the checkout must outlive this script because
        # the installed package keeps using the sources in that directory.
        _run(sys.executable + " -m pip install -e .")
        # Install PyTorch XLA pallas
        _run(
            sys.executable
            + f" -m pip install torch_xla[pallas]=={TORCH_VER} "
            + " -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html"
            + " -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html"
            + " -f https://storage.googleapis.com/libtpu-releases/index.html"
        )
    finally:
        # Best effort: the previous directory may no longer exist (e.g. it was
        # inside the checkout removed above and the re-clone failed).
        try:
            os.chdir(previous_cwd)
        except OSError:
            pass

    click.echo()
    click.echo(click.style("Jetstream Pytorch has been installed.", bold=True))
if __name__ == "__main__":
    # Run the Typer CLI and use its return value as the process exit status.
    raise SystemExit(app())