in src/sal/models/skywork_o1_prm/modeling_base.py [0:0]
def _get_current_device(cls):
    r"""
    Resolve the device identifier for the current process.

    An `accelerate.PartialState` object is created first. If `torch.cuda.is_available()` reports CUDA
    support, that state's `local_process_index` is returned; otherwise the string `"cpu"` is returned.

    Returns:
        current_device (`Union[int, str]`):
            The `local_process_index` of the `PartialState` when CUDA is available, else `"cpu"`.
    """
    # Build the state before probing CUDA, mirroring the original evaluation order.
    process_state = PartialState()
    if not torch.cuda.is_available():
        return "cpu"
    return process_state.local_process_index