in dags/pytorch_xla/configs/pytorchxla_torchbench_config.py [0:0]
def get_nvidia_driver_install_cmd(driver_version: str) -> "tuple[str, ...]":
  """Build the shell commands that install a pinned NVIDIA driver version.

  The returned commands, in order:
    1. download GoogleCloudPlatform's `install_gpu_driver.py` helper script;
    2. comment out its `apt update` ... `return True` section (see below);
    3. rewrite the script's `DRIVER_VERSION = ...` line to `driver_version`;
    4. run the installer as root with `--force`;
    5. run `nvidia-smi` as a post-install check.

  How the commands are executed (e.g. chained or run one by one) is up to
  the caller.

  Args:
    driver_version: NVIDIA driver version to install, e.g. "535.86.10".
      It is embedded inside a single-quoted `sed` expression, so it may only
      contain ASCII letters, digits, '.', '_' and '-'.

  Returns:
    A tuple of shell command strings. Note: this is a tuple, not a single
    string (the previous `-> str` annotation was incorrect).

  Raises:
    ValueError: if `driver_version` is empty or contains characters that
      would break the shell quoting or the sed substitution (such as a
      quote, '/', '&' or '\\').
  """
  import re

  # str() keeps callers that pass a number (e.g. 535.86) working as before.
  version = str(driver_version)
  # Reject anything that could escape the single-quoted sed script (a `'`
  # would allow shell injection) or be interpreted by sed itself.
  if not re.fullmatch(r"[A-Za-z0-9._-]+", version):
    raise ValueError(
        f"Invalid NVIDIA driver version {version!r}: only letters, digits, "
        "'.', '_' and '-' are allowed."
    )

  nvidia_driver_install = (
      "curl -s https://raw.githubusercontent.com/GoogleCloudPlatform/compute-gpu-installation/main/linux/install_gpu_driver.py --output install_gpu_driver.py",
      # Command `apt update/upgrade` receives 403 bad gateway error when connecting to the google apt repo.
      # This can be a transient error. We use the following command to fix the issue for now.
      # TODO(piz): remove the following statement for temporary fix once the `apt update/upgrade` is removed or updated.
      "sed -i '/^\\s*run(\"apt update\")/,/^\\s*return True/ s/^/# /' install_gpu_driver.py",
      # Replace the right-hand side of the script's DRIVER_VERSION assignment
      # with the requested version (as a double-quoted Python string literal).
      f"sed -i 's/^\\(DRIVER_VERSION = \\).*/\\1\"{version}\"/' install_gpu_driver.py",
      "sudo python3 install_gpu_driver.py --force",
      "sudo nvidia-smi",
  )
  return nvidia_driver_install