def get_nvidia_driver_install_cmd()

in dags/pytorch_xla/configs/pytorchxla_torchbench_config.py [0:0]


  def get_nvidia_driver_install_cmd(driver_version: str) -> "tuple[str, ...]":
    """Build the shell commands that install a pinned NVIDIA driver version.

    The commands do the following, in order:
      1. Download `install_gpu_driver.py` from the `main` branch of
         GoogleCloudPlatform/compute-gpu-installation.
      2. Comment out the script's `apt update` block.
      3. Pin the script's `DRIVER_VERSION` to `driver_version`.
      4. Run the installer with sudo and `--force`.
      5. Run `nvidia-smi`.

    Args:
      driver_version: Driver version written into the installer's
        `DRIVER_VERSION` assignment, e.g. "535.104.05". It is inserted
        verbatim into a single-quoted sed replacement. It must therefore not
        contain `/`, `&`, `\\` or `'`, or the sed command will break.

    Returns:
      A tuple of shell command strings, in execution order. How they are run
      (joined with `&&`, executed one by one, etc.) is up to the caller.
    """
    nvidia_driver_install = (
        # NOTE(review): `curl -s` without `-f` exits 0 on HTTP errors and
        # writes the error body to install_gpu_driver.py. Consider `-f` if the
        # caller relies on exit codes.
        "curl -s https://raw.githubusercontent.com/GoogleCloudPlatform/compute-gpu-installation/main/linux/install_gpu_driver.py --output install_gpu_driver.py",
        # `apt update/upgrade` inside the installer fails with an HTTP error
        # when it reaches the Google apt repo. It was reported as "403 bad
        # gateway"; 403 is Forbidden and 502 is Bad Gateway, so confirm which.
        # The failure may be transient. As a temporary fix, comment out every
        # line from `run("apt update")` through the next `return True`.
        # TODO(piz): remove the following statement for temporary fix once the `apt update/upgrade` is removed or updated.
        "sed -i '/^\\s*run(\"apt update\")/,/^\\s*return True/ s/^/# /'  install_gpu_driver.py",
        # Rewrite the `DRIVER_VERSION = ...` line so that it assigns the
        # requested version as a double-quoted string literal.
        f"sed -i 's/^\\(DRIVER_VERSION = \\).*/\\1\"{driver_version}\"/' install_gpu_driver.py",
        "sudo python3 install_gpu_driver.py --force",
        # Presumably a smoke check that the driver loaded.
        "sudo nvidia-smi",
    )
    return nvidia_driver_install