def infer_memory_requirements()

in local_gemma/utils/config.py


from typing import Tuple

import psutil
import torch
from accelerate import init_empty_weights
from accelerate.utils import calculate_maximum_sizes
from transformers import Gemma2Config, Gemma2ForCausalLM

# DTYPE_MODIFIER (a mapping from preset name to a dtype size divisor) and infer_device
# are local_gemma helpers (not shown here).


def infer_memory_requirements(model_name, device=None, token=None, trust_remote_code=False) -> Tuple[str, float]:
    config = Gemma2Config.from_pretrained(model_name, token=token, trust_remote_code=trust_remote_code)

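    # instantiate the model on the meta device so no weight memory is actually allocated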
    with init_empty_weights():
        model = Gemma2ForCausalLM(config)

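    # total_size is the full weight footprint in bytes (float32 by default); the divisors
    # in DTYPE_MODIFIER scale it down for the lower-precision presets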
    total_size, _ = calculate_maximum_sizes(model)
    device = infer_device(device)

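    # total memory of the target device: GPU VRAM for CUDA, system RAM otherwise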
    if device == "cuda":
        total_memory = torch.cuda.get_device_properties(device).total_memory
    else:
        total_memory = psutil.virtual_memory().total

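    # presets are assumed to be ordered from most to least memory-hungry; return the
    # first one whose estimated footprint (plus ~15% inference overhead) fits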
    for preset in DTYPE_MODIFIER.keys():
        dtype_total_size = total_size / DTYPE_MODIFIER[preset]
        inference_requirements = 1.15 * dtype_total_size  # 1.15 allows A10G to run the `exact` preset on the 9b model
        spare_memory = total_memory - inference_requirements

        if inference_requirements < total_memory:
            return preset, spare_memory

    # if the model does not fit fully on the device, return the last preset ('memory_extreme'),
    # which automatically enables CPU offloading so that the model fits on any device
    return "memory_extreme", 0
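

A minimal usage sketch, assuming the function is imported from local_gemma.utils.config and using google/gemma-2-9b purely as an illustrative checkpoint:

from local_gemma.utils.config import infer_memory_requirements

# returns the selected preset name and the estimated spare memory in bytes
preset, spare_memory = infer_memory_requirements("google/gemma-2-9b", device="cuda")
print(f"selected preset: {preset}, spare memory: {spare_memory / 1e9:.1f} GB")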