in scripts/eval_jat.py [0:0]
def get_default_device() -> torch.device:
    """Select the torch device to use by default.

    The "mps" backend is chosen when torch reports it as both available and
    built; otherwise "cuda" is chosen when torch reports CUDA as available;
    otherwise the function falls back to "cpu".

    Returns:
        torch.device: The device for the first backend, in the order above,
        whose checks pass.
    """
    mps = torch.backends.mps
    if mps.is_available() and mps.is_built():
        backend = "mps"
    else:
        # CUDA is only probed once the mps checks have failed.
        backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)