optimum/tpu/jetstream_pt_support.py (11 lines of code) (raw):

import os


def jetstream_pt_available() -> bool:
    """Check if the necessary imports to use jetstream_pt are available.

    Jetstream Pytorch is enabled by default. It can be disabled by setting the
    ``JETSTREAM_PT_DISABLE`` environment variable to exactly ``"1"``.

    Returns:
        ``False`` if Jetstream Pytorch is disabled through the environment
        variable, or if importing ``torch_xla2`` or ``jetstream_pt`` raises
        ``ImportError``; ``True`` otherwise.
    """
    # Only the exact value "1" disables the integration; an unset variable or
    # any other value leaves it enabled.
    if os.environ.get("JETSTREAM_PT_DISABLE", "") == "1":
        return False
    try:
        # Import torch_xla2 first! The order is deliberate, hence isort:skip.
        import torch_xla2  # noqa: F401, isort:skip
        import jetstream_pt  # noqa: F401
    except ImportError:
        return False
    return True