in local_gemma/utils/config.py [0:0]
import psutil
import torch
from typing import Tuple

from accelerate import init_empty_weights
from accelerate.utils import calculate_maximum_sizes
from transformers import Gemma2Config, Gemma2ForCausalLM

# `DTYPE_MODIFIER` and `infer_device` are helpers assumed to be defined elsewhere in this module.


def infer_memory_requirements(model_name, device=None, token=None, trust_remote_code=False) -> Tuple[str, float]:
    config = Gemma2Config.from_pretrained(model_name, token=token, trust_remote_code=trust_remote_code)
    # Instantiate the model on the meta device so its size can be measured without allocating weights
    with init_empty_weights():
        model = Gemma2ForCausalLM(config)
    total_size, _ = calculate_maximum_sizes(model)

    device = infer_device(device)
    if device == "cuda":
        total_memory = torch.cuda.get_device_properties(device).total_memory
    else:
        total_memory = psutil.virtual_memory().total

    # Walk the presets from most to least precise and return the first one whose footprint fits the device
    for preset in DTYPE_MODIFIER.keys():
        dtype_total_size = total_size / DTYPE_MODIFIER[preset]
        inference_requirements = 1.15 * dtype_total_size  # 1.15 allows an A10G to run the `exact` preset on the 9b model
        spare_memory = total_memory - inference_requirements
        if inference_requirements < total_memory:
            return preset, spare_memory

    # If the model does not fit fully on the device, return the last preset ("memory_extreme"), which
    # automatically enables CPU offloading so that the model fits on any device
    return "memory_extreme", 0