def _get_current_device()

in src/sal/models/skywork_o1_prm/modeling_base.py [0:0]


    def _get_current_device(cls):
        r"""
        Resolve the device that the current process should use.

        An `accelerate.PartialState` is always instantiated first. When CUDA is available, its
        `local_process_index` is used as the device. Using the local index rather than a fixed
        ordinal is what makes this correct in distributed launches, where each process owns a
        different local GPU. Without CUDA the CPU is selected.

        Returns:
            current_device (`Union[int, str]`):
                The local process index (an int) when CUDA is available, otherwise the string "cpu".
        """
        # Instantiate the shared state unconditionally, exactly as before, so any side effects
        # of creating it still happen even on CPU-only machines.
        process_state = PartialState()
        if not torch.cuda.is_available():
            return "cpu"
        return process_state.local_process_index