def get_gpu_device_ids()

in optimum_benchmark/system_utils.py [0:0]


def get_gpu_device_ids() -> str:
    """Return the visible GPU device ids as a comma-separated string (e.g. ``"0,1"``).

    Resolution order:

    * NVIDIA systems: ``NVIDIA_VISIBLE_DEVICES``, then ``CUDA_VISIBLE_DEVICES``,
      otherwise ``0..N-1`` where ``N`` is the device count reported by PyNVML.
    * ROCm systems: ``ROCR_VISIBLE_DEVICES``, then ``CUDA_VISIBLE_DEVICES``,
      otherwise ``0..N-1`` where ``N`` is the device count reported by PyRSMI
      (preferred when installed) or AMD SMI.

    Environment variable values are returned verbatim, without parsing.

    Returns:
        Comma-separated device ids.

    Raises:
        ValueError: If the system is neither NVIDIA nor ROCm, or if no
            visibility environment variable is set and the required
            management library is not installed.
    """
    if is_nvidia_system():
        # NOTE(review): NVIDIA_VISIBLE_DEVICES can also hold values such as
        # "all" or GPU UUIDs; it is passed through unchanged here.
        device_ids = os.environ.get("NVIDIA_VISIBLE_DEVICES")
        if device_ids is None:
            device_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
        if device_ids is None:
            if not is_pynvml_available():
                raise ValueError(
                    "The library PyNVML is required to get GPU device ids, but is not installed. "
                    "Please install the official and NVIDIA maintained PyNVML library through `pip install nvidia-ml-py`."
                )

            pynvml.nvmlInit()
            try:
                device_count = pynvml.nvmlDeviceGetCount()
            finally:
                # Release NVML even if the count query raises.
                pynvml.nvmlShutdown()
            device_ids = ",".join(str(i) for i in range(device_count))
    elif is_rocm_system():
        device_ids = os.environ.get("ROCR_VISIBLE_DEVICES")
        if device_ids is None:
            device_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
        if device_ids is None:
            # Either library is sufficient; only fail when BOTH are missing.
            if not is_amdsmi_available() and not is_pyrsmi_available():
                raise ValueError(
                    "Either the library AMD SMI or PyRSMI is required to get GPU device ids, but neither is installed. "
                    "Please install the official and AMD maintained AMD SMI library from https://github.com/ROCm/amdsmi "
                    "or PyRSMI library from https://github.com/ROCm/pyrsmi."
                )

            if is_pyrsmi_available():
                rocml.smi_initialize()
                try:
                    device_count = rocml.smi_get_device_count()
                finally:
                    # Release PyRSMI even if the count query raises.
                    rocml.smi_shutdown()
            else:
                # Guard above ensures AMD SMI is available when PyRSMI is not.
                amdsmi.amdsmi_init()
                try:
                    device_count = len(amdsmi.amdsmi_get_processor_handles())
                finally:
                    # Release AMD SMI even if the handle query raises.
                    amdsmi.amdsmi_shut_down()
            device_ids = ",".join(str(i) for i in range(device_count))
    else:
        raise ValueError("Couldn't infer GPU device ids.")

    return device_ids