def get_max_memory()

in src/accelerate/utils/modeling.py [0:0]


def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None):
    """
    Get the maximum memory available if nothing is passed, converts string to int otherwise.
    """
    import psutil

    if max_memory is None:
        max_memory = {}
        # Make sure CUDA is initialized on each GPU to have the right memory info.
        if is_npu_available():
            for i in range(torch.npu.device_count()):
                try:
                    _ = torch.tensor(0, device=torch.device("npu", i))
                    max_memory[i] = torch.npu.mem_get_info(i)[0]
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        elif is_mlu_available():
            for i in range(torch.mlu.device_count()):
                try:
                    _ = torch.tensor(0, device=torch.device("mlu", i))
                    max_memory[i] = torch.mlu.mem_get_info(i)[0]
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        elif is_sdaa_available():
            for i in range(torch.sdaa.device_count()):
                try:
                    _ = torch.tensor(0, device=torch.device("sdaa", i))
                    max_memory[i] = torch.sdaa.mem_get_info(i)[0]
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        elif is_musa_available():
            for i in range(torch.musa.device_count()):
                try:
                    _ = torch.tensor(0, device=torch.device("musa", i))
                    max_memory[i] = torch.musa.mem_get_info(i)[0]
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        elif is_xpu_available():
            for i in range(torch.xpu.device_count()):
                try:
                    _ = torch.tensor(0, device=torch.device("xpu", i))
                    max_memory[i] = get_xpu_available_memory(i)
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        elif is_hpu_available():
            for i in range(torch.hpu.device_count()):
                try:
                    _ = torch.tensor(0, device=torch.device("hpu", i))
                    max_memory[i] = torch.hpu.mem_get_info(i)[0]
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        else:
            for i in range(torch.cuda.device_count()):
                try:
                    _ = torch.tensor([0], device=i)
                    max_memory[i] = torch.cuda.mem_get_info(i)[0]
                except Exception:
                    logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
                    continue
        # allocate everything in the mps device as the RAM is shared
        if is_mps_available():
            max_memory["mps"] = psutil.virtual_memory().available
        else:
            max_memory["cpu"] = psutil.virtual_memory().available
        return max_memory

    for key in max_memory:
        if isinstance(max_memory[key], str):
            max_memory[key] = convert_file_size_to_int(max_memory[key])

    # Need to sort the device by type to make sure that we allocate the gpu first.
    # As gpu/npu/xpu are represented by int, we need to sort them first.
    gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
    gpu_devices.sort()
    # check if gpu/npu/xpu devices are available and if not, throw a warning
    if is_npu_available():
        num_devices = torch.npu.device_count()
    elif is_mlu_available():
        num_devices = torch.mlu.device_count()
    elif is_sdaa_available():
        num_devices = torch.sdaa.device_count()
    elif is_musa_available():
        num_devices = torch.musa.device_count()
    elif is_xpu_available():
        num_devices = torch.xpu.device_count()
    elif is_hpu_available():
        num_devices = torch.hpu.device_count()
    else:
        num_devices = torch.cuda.device_count()
    for device in gpu_devices:
        if device >= num_devices or device < 0:
            logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
    # Add the other devices in the preset order if they are available
    all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()]
    # Raise an error if a device is not recognized
    for k in max_memory.keys():
        if k not in all_devices:
            raise ValueError(
                f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'"
            )
    max_memory = {k: max_memory[k] for k in all_devices}

    return max_memory