in scripts/launcher_distributed_kd.py [0:0]
def check_pytorch_version() -> Optional[str]:
    """
    Check and return the installed PyTorch version.

    Spawns a child Python process that imports torch and prints
    ``torch.__version__``. The child is started with the interpreter that is
    running this script (``sys.executable``), so the reported version is the
    one installed in the active environment rather than whatever ``python``
    happens to resolve to on PATH.

    Returns:
        Optional[str]: The PyTorch version as a string, or None if it could
        not be determined (the child exited with a non-zero status, the
        interpreter could not be launched, or another error occurred).

    Note:
        No exception propagates to the caller: every failure is printed and
        reported as a ``None`` return value.
    """
    # Function-scope import so this block does not depend on a file-level
    # ``import sys`` being present.
    import sys

    # sys.executable can be empty or None (e.g. embedded interpreters); fall
    # back to a PATH lookup in that case.
    interpreter = sys.executable or "python"

    try:
        # Only the subprocess call lives inside the try: it is the one
        # statement expected to fail.
        result = sb.run(
            [interpreter, "-c", "import torch; print(torch.__version__)"],
            capture_output=True,
            text=True,
            check=True,
        )
    except sb.CalledProcessError as e:
        # Non-zero exit status, typically because torch is not importable.
        print(f"Error occurred while checking PyTorch version: {e}")
        print(f"Error output: {e.stderr}")
        return None
    except Exception as e:
        # Deliberately broad: this is a best-effort probe, and a failure to
        # launch the interpreter must not abort the caller.
        print(f"Unexpected error occurred: {e}")
        return None

    # Drop the trailing newline emitted by print() in the child process.
    version = result.stdout.strip()
    print(f"Installed PyTorch version: {version}")
    return version