local_gemma/utils/config.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import psutil
import torch
from typing import Dict, Optional, Tuple

from accelerate import init_empty_weights
from accelerate.utils import calculate_maximum_sizes
from transformers import Gemma2Config, Gemma2ForCausalLM
from transformers.utils import is_torch_bf16_available_on_device

# Divisor applied to the fp32 model footprint for each preset: "exact" and "speed" run the weights in
# 16-bit precision (half the fp32 size), while the memory presets assume 4-bit quantization (an eighth of it).
DTYPE_MODIFIER = {"exact": 2, "speed": 2, "memory": 8, "memory_extreme": 8}
DTYPE_MAP = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}


def infer_device(device: Optional[str] = None) -> str:
    """
    Infers the device to run on, preferring the most performant backend available
    (CUDA, then Apple MPS, then CPU). An explicitly passed device is returned unchanged.
    """
if device is not None:
return device
elif torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
return "cpu"


def infer_dtype(device: str, dtype_str: Optional[str] = None) -> torch.dtype:
    """Resolves the torch dtype, defaulting to bfloat16 when the device supports it and float16 otherwise."""
    if dtype_str is None:
        if is_torch_bf16_available_on_device(device):
            return torch.bfloat16
        else:
            return torch.float16
dtype = DTYPE_MAP.get(dtype_str, None)
if dtype is None:
        raise ValueError(f"Unknown dtype: {dtype_str}. Must be one of {list(DTYPE_MAP)}.")
return dtype


def get_prompt(mode: str) -> str:
    """Returns the instruction prompt used for the given generation mode ("chat", "factual" or "creative")."""
    if mode == "chat":
        return ""
elif mode == "factual":
return "Please reply to the following requests with short and factual answers.\n\n"
elif mode == "creative":
return (
"Write a response that appropriately completes the request. Be descriptive, fluid, and follow the context "
"provided.\n\n"
)
else:
raise ValueError(f"Unknown mode: {mode}")


def get_generation_kwargs(mode: str) -> Dict:
    """Returns the sampling parameters associated with the given generation mode."""
    generation_kwargs = {"do_sample": True}
    if mode == "chat":
generation_kwargs["temperature"] = 0.7
elif mode == "factual":
generation_kwargs["temperature"] = 0.3
generation_kwargs["repetition_penalty"] = 1.2
elif mode == "creative":
generation_kwargs["min_p"] = 0.08
generation_kwargs["temperature"] = 1.5
else:
raise ValueError(f"Unknown mode: {mode}")
return generation_kwargs


def infer_memory_requirements(model_name, device=None, token=None, trust_remote_code=False) -> Tuple[str, float]:
    """Returns the most precise preset that fits on the target device, plus the estimated spare memory in bytes."""
    config = Gemma2Config.from_pretrained(model_name, token=token, trust_remote_code=trust_remote_code)
    # Build the model on the meta device so its full-precision size can be measured without allocating weights
    with init_empty_weights():
        model = Gemma2ForCausalLM(config)
    total_size, _ = calculate_maximum_sizes(model)
device = infer_device(device)
if device == "cuda":
total_memory = torch.cuda.get_device_properties(device).total_memory
else:
total_memory = psutil.virtual_memory().total
    # Presets are ordered from most to least precise; pick the first one that fits
    for preset, dtype_modifier in DTYPE_MODIFIER.items():
        dtype_total_size = total_size / dtype_modifier
inference_requirements = 1.15 * dtype_total_size # 1.15 allows A10G to run the `exact` preset on the 9b model
spare_memory = total_memory - inference_requirements
if inference_requirements < total_memory:
return preset, spare_memory
    # if the model does not fit fully on the device, fall back to the last preset ("memory_extreme"), which
    # enables CPU offloading so that the model fits on any device
    return "memory_extreme", 0
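

# Illustrative usage sketch, not part of the module's public API: it shows how the helpers above could be
# combined before generation. The checkpoint name is only an example and assumes access to the gated
# Gemma 2 weights (e.g. after `huggingface-cli login`).
if __name__ == "__main__":
    device = infer_device()
    dtype = infer_dtype(device)
    preset, spare_memory = infer_memory_requirements("google/gemma-2-2b-it", device=device)
    prompt = get_prompt("factual") + "What is the capital of France?"
    generation_kwargs = get_generation_kwargs("factual")
    print(f"device={device}, dtype={dtype}, preset={preset}, spare={spare_memory / 1e9:.1f} GB")
    print(f"prompt={prompt!r}, generation_kwargs={generation_kwargs}")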