def get_preset_kwargs()

in local_gemma/modeling_local_gemma_2.py


    def get_preset_kwargs(
        pretrained_model_name_or_path: str,
        preset: str,
        device: str,
        trust_remote_code: bool = False,
        token: Optional[str] = None,
    ) -> Dict:
        if preset not in PRESET_MAPPING:
            raise ValueError(f"Got invalid `preset` {preset}. Ensure `preset` is one of: {list(PRESET_MAPPING.keys())}")

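        # "auto" is not returned as-is: it is resolved to a concrete preset based on the
        # model's memory requirements and the target device.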
        if preset == "auto":
            preset, _ = infer_memory_requirements(
                pretrained_model_name_or_path, device, trust_remote_code=trust_remote_code, token=token
            )
            logger.info(f"Detected device {device} and defaulting to {preset} preset.")

        # Copy the mapping entry so the per-call tweaks below don't mutate the shared PRESET_MAPPING.
        preset_kwargs = dict(PRESET_MAPPING[preset])

        if preset == "speed" and device != "cuda":
            # disable torch compile on non-cuda devices since it's not compatible
            preset_kwargs["torch_compile"] = False

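        # The memory presets need a quantization backend: bitsandbytes on CUDA, quanto on other devices.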
        if preset in ["memory", "memory_extreme"]:
            if device == "cuda" and not is_bitsandbytes_available():
                raise ImportError(
                    f"The {preset} preset on CUDA requires the `bitsandbytes` package. Please install bitsandbytes through: "
                    "`pip install --upgrade bitsandbytes`."
                )
            elif device != "cuda" and not is_quanto_available():
                raise ImportError(
                    f"The {preset} preset on {device} requires the `quanto` package. Please install quanto through: "
                    "`pip install --upgrade quanto`."
                )

        if preset == "memory_extreme":
            if not is_accelerate_available():
                raise ImportError(
                    "The `memory_extreme` preset requires the `accelerate` package. Please install accelerate through: "
                    "`pip install --upgrade accelerate`."
                )

        return preset_kwargs
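
Usage sketch (assumptions, not verbatim library code: `get_preset_kwargs` is in scope — it may be a static helper on the model class in the actual module — the checkpoint id is illustrative, and the call site is assumed to pop non-loader keys such as "torch_compile" before forwarding the rest to `from_pretrained`):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    preset_kwargs = get_preset_kwargs(
        "google/gemma-2-9b-it",  # illustrative checkpoint id
        preset="auto",           # resolved to a concrete preset via infer_memory_requirements
        device=device,
    )

    # Assumed call-site contract: "torch_compile" is not a from_pretrained kwarg,
    # so pop it off before forwarding the remaining kwargs to the model loader.
    torch_compile = preset_kwargs.pop("torch_compile", False)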