in local_gemma/modeling_local_gemma_2.py [0:0]
def get_preset_kwargs(pretrained_model_name_or_path: str, preset: str, device: str, trust_remote_code: bool = False, token: str = None) -> Dict:
    """Resolve a preset name into the kwargs used to load/generate with the model.

    Args:
        pretrained_model_name_or_path: Hub id or local path of the checkpoint;
            only used when `preset == "auto"` to infer memory requirements.
        preset: One of the keys of `PRESET_MAPPING` (e.g. "auto", "speed",
            "memory", "memory_extreme").
        device: Target device string (e.g. "cuda", "cpu", "mps"); influences
            both auto-detection and which quantization backend is required.
        trust_remote_code: Forwarded to `infer_memory_requirements` for "auto".
        token: Optional Hub auth token, forwarded to `infer_memory_requirements`.

    Returns:
        A fresh dict of preset kwargs (safe for the caller to mutate).

    Raises:
        ValueError: if `preset` is not a known preset name.
        ImportError: if the backend required by the chosen preset/device
            combination (bitsandbytes, quanto, or accelerate) is missing.
    """
    if preset not in PRESET_MAPPING:
        raise ValueError(f"Got invalid `preset` {preset}. Ensure `preset` is one of: {list(PRESET_MAPPING.keys())}")
    if preset == "auto":
        # Replace "auto" with a concrete preset inferred from available memory.
        preset, _ = infer_memory_requirements(
            pretrained_model_name_or_path, device, trust_remote_code=trust_remote_code, token=token
        )
        logger.info(f"Detected device {device} and defaulting to {preset} preset.")
    # Copy so the per-call tweak below does not mutate the shared module-level
    # PRESET_MAPPING entry (the original aliased it, permanently disabling
    # torch_compile for all subsequent calls once hit on a non-CUDA device).
    preset_kwargs = dict(PRESET_MAPPING[preset])
    if preset == "speed" and device != "cuda":
        # disable torch compile on non-cuda devices since it's not compatible
        preset_kwargs["torch_compile"] = False
    if preset in ["memory", "memory_extreme"]:
        # Quantized presets need a backend: bitsandbytes on CUDA, quanto elsewhere.
        if device == "cuda" and not is_bitsandbytes_available():
            raise ImportError(
                f"The {preset} preset on CUDA requires the `bitsandbytes` package. Please install bitsandbytes through: "
                "`pip install --upgrade bitsandbytes`."
            )
        elif device != "cuda" and not is_quanto_available():
            raise ImportError(
                f"The {preset} preset on {device} requires the `quanto` package. Please install quanto through: "
                "`pip install --upgrade quanto`."
            )
    if preset == "memory_extreme":
        # CPU offload in this preset is implemented via accelerate.
        if not is_accelerate_available():
            raise ImportError(
                f"The `memory_extreme` preset requires the `accelerate` package. Please install accelerate through: "
                "`pip install --upgrade accelerate`."
            )
    return preset_kwargs