def get_gpu_vram_mb()

in optimum_benchmark/system_utils.py [0:0]


def get_gpu_vram_mb() -> int:
    """Return the total VRAM summed over every GPU reported by the vendor library.

    On NVIDIA systems the per-device totals come from PyNVML; on ROCm systems
    they come from AMD SMI when it is available, otherwise from PyRSMI.

    Returns:
        The sum of the per-device memory totals, exactly as reported by the
        vendor library (no unit conversion is applied here).
        NOTE(review): the function name says "mb", but the values are passed
        through unconverted and NVML documents `total` in bytes -- confirm
        which unit callers expect before relying on it.

    Raises:
        ValueError: If the library needed for the detected system is not
            installed, or if the system is neither NVIDIA nor ROCm.
    """
    if is_nvidia_system():
        if not is_pynvml_available():
            raise ValueError(
                "The library PyNVML is required to get GPU VRAM, but is not installed. "
                "Please install the official and NVIDIA maintained PyNVML library through `pip install nvidia-ml-py`."
            )

        pynvml.nvmlInit()
        try:
            vrams = [
                pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(i)).total
                for i in range(pynvml.nvmlDeviceGetCount())
            ]
        finally:
            # Release NVML even when a device query raises.
            pynvml.nvmlShutdown()

    elif is_rocm_system():
        if not is_amdsmi_available() and not is_pyrsmi_available():
            raise ValueError(
                "Either the library AMD SMI or PyRSMI is required to get GPU VRAM, but neither is installed. "
                "Please install the official and AMD maintained AMD SMI library from https://github.com/ROCm/amdsmi "
                "or PyRSMI library from https://github.com/ROCm/pyrsmi."
            )

        if is_amdsmi_available():
            amdsmi.amdsmi_init()
            try:
                vrams = [
                    amdsmi.amdsmi_get_gpu_memory_total(handle, mem_type=amdsmi.AmdSmiMemoryType.VRAM)
                    for handle in amdsmi.amdsmi_get_processor_handles()
                ]
            finally:
                # Release AMD SMI even when a device query raises.
                amdsmi.amdsmi_shut_down()

        else:
            # The guard above already established that PyRSMI is available
            # whenever AMD SMI is not, so no further check is needed and
            # `vrams` is always bound on this path.
            rocml.smi_initialize()
            try:
                vrams = [rocml.smi_get_device_memory_total(i) for i in range(rocml.smi_get_device_count())]
            finally:
                # Release PyRSMI even when a device query raises.
                rocml.smi_shutdown()

    else:
        raise ValueError("No NVIDIA or ROCm GPUs found.")

    return sum(vrams)