def infer_dtype()

in local_gemma/utils/config.py [0:0]


def infer_dtype(device: str, dtype_str: Optional[str] = None) -> torch.dtype:
    """Resolve the torch dtype to use, either automatically or from a name.

    Args:
        device: Device identifier. It is passed to
            ``is_torch_bf16_available_on_device`` only when ``dtype_str`` is ``None``.
        dtype_str: Optional dtype name, looked up in ``DTYPE_MAP``. When ``None``,
            the dtype is chosen automatically from ``device``.

    Returns:
        If ``dtype_str`` is ``None``: ``torch.bfloat16`` when bf16 is reported as
        available on ``device``, otherwise ``torch.float16``.
        Otherwise: ``DTYPE_MAP[dtype_str]``.

    Raises:
        ValueError: If ``dtype_str`` is given but has no entry in ``DTYPE_MAP``.
    """
    if dtype_str is None:
        # Prefer bf16 when the device supports it; otherwise fall back to fp16.
        return (
            torch.bfloat16
            if is_torch_bf16_available_on_device(device)
            else torch.float16
        )

    dtype = DTYPE_MAP.get(dtype_str)
    if dtype is None:
        # list(...) renders the options as "[...]" instead of "dict_keys([...])".
        raise ValueError(
            f"Unknown dtype: {dtype_str}. Must be one of {list(DTYPE_MAP)}"
        )
    return dtype