in local_gemma/utils/config.py [0:0]
def infer_dtype(device: str, dtype_str: Optional[str] = None) -> torch.dtype:
    """Resolve a torch dtype from an optional name, defaulting by device.

    Args:
        device: Device identifier passed to ``is_torch_bf16_available_on_device``
            when no explicit dtype name is given.
        dtype_str: Optional dtype name looked up in ``DTYPE_MAP``. When ``None``,
            the dtype is chosen based on ``device``.

    Returns:
        If ``dtype_str`` is ``None``: ``torch.bfloat16`` when the device reports
        bf16 support, otherwise ``torch.float16``. Otherwise the ``DTYPE_MAP``
        entry for ``dtype_str``.

    Raises:
        ValueError: If ``dtype_str`` is not a key of ``DTYPE_MAP``.
    """
    if dtype_str is None:
        # No explicit request: prefer bf16 when supported, else fall back to fp16.
        if is_torch_bf16_available_on_device(device):
            return torch.bfloat16
        return torch.float16

    dtype = DTYPE_MAP.get(dtype_str)
    if dtype is None:
        # list(...) prints the valid names as ['a', 'b'] instead of the
        # less readable dict_keys(['a', 'b']) repr.
        raise ValueError(f"Unknown dtype: {dtype_str}. Must be one of {list(DTYPE_MAP)}")
    return dtype