# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union, Dict
import logging
from tqdm import tqdm
import torch
from transformers import QuantoConfig, is_bitsandbytes_available, BitsAndBytesConfig
from transformers.utils import is_quanto_available, is_accelerate_available
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, GEMMA2_ATTENTION_CLASSES
import transformers.models.gemma2.modeling_gemma2
from .attention import Gemma2FusedAttention
from .utils.config import infer_device, infer_dtype, infer_memory_requirements
logger = logging.getLogger(__name__)
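# Generation presets: each maps `from_pretrained` kwargs to a different accuracy /
# speed / memory trade-off. "auto" is resolved in `get_preset_kwargs` based on the
# detected device and the model's memory requirements.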
EXACT = {
"attn_implementation": "eager",
}
SPEED = {
"attn_implementation": "eager",
"torch_compile": True,
}
MEMORY = {
"attn_implementation": "eager",
"quantization_config": {
"weights": "int4"
}
}
MEMORY_EXTREME = {
"attn_implementation": "eager",
"device_map": "auto",
"quantization_config": {
"weights": "int4"
}
}
PRESET_MAPPING = {
"auto": None,
"exact": EXACT,
"speed": SPEED,
"memory": MEMORY,
"memory_extreme": MEMORY_EXTREME,
}
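# register the fused attention implementation alongside the stock Gemma 2 attention classes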
transformers.models.gemma2.modeling_gemma2.GEMMA2_ATTENTION_CLASSES = {
**GEMMA2_ATTENTION_CLASSES,
"fused": Gemma2FusedAttention,
}
class LocalGemma2ForCausalLM(Gemma2ForCausalLM):
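    """`Gemma2ForCausalLM` with device-aware loading presets.

    A minimal usage sketch (the checkpoint name is illustrative; any Gemma 2
    causal-LM checkpoint should work, provided the backend required by the
    selected preset is installed):

        >>> from local_gemma.modeling_local_gemma_2 import LocalGemma2ForCausalLM
        >>> model = LocalGemma2ForCausalLM.from_pretrained(
        ...     "google/gemma-2-9b-it", preset="memory"
        ... )
    """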
@staticmethod
    def get_preset_kwargs(pretrained_model_name_or_path: str, preset: str, device: str, trust_remote_code: bool = False, token: Optional[str] = None) -> Dict:
if preset not in PRESET_MAPPING:
raise ValueError(f"Got invalid `preset` {preset}. Ensure `preset` is one of: {list(PRESET_MAPPING.keys())}")
if preset == "auto":
preset, _ = infer_memory_requirements(
pretrained_model_name_or_path, device, trust_remote_code=trust_remote_code, token=token
)
logger.info(f"Detected device {device} and defaulting to {preset} preset.")
        # copy the preset so the per-call adjustments below do not mutate the module-level defaults
        preset_kwargs = dict(PRESET_MAPPING[preset])
if preset == "speed" and device != "cuda":
# disable torch compile on non-cuda devices since it's not compatible
preset_kwargs["torch_compile"] = False
if preset in ["memory", "memory_extreme"]:
if device == "cuda" and not is_bitsandbytes_available():
raise ImportError(
f"The {preset} preset on CUDA requires the `bitsandbytes` package. Please install bitsandbytes through: "
"`pip install --upgrade bitsandbytes`."
)
elif device != "cuda" and not is_quanto_available():
raise ImportError(
f"The {preset} preset on {device} requires the `quanto` package. Please install quanto through: "
"`pip install --upgrade quanto`."
)
if preset == "memory_extreme":
if not is_accelerate_available():
raise ImportError(
f"The `memory_extreme` preset requires the `accelerate` package. Please install accelerate through: "
"`pip install --upgrade accelerate`."
)
return preset_kwargs
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
preset: Optional[str] = "auto",
torch_compile: Optional[bool] = None,
*model_args,
config: Optional[Union[Gemma2Config, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
        use_safetensors: Optional[bool] = None,
**kwargs,
) -> Gemma2ForCausalLM:
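        """Load a Gemma 2 checkpoint with kwargs derived from a preset.

        `preset` selects one of the configurations in `PRESET_MAPPING` ("auto"
        infers one from the device and the model's memory requirements), and
        `torch_compile`, when not None, overrides the preset's compile setting.
        The remaining arguments mirror `transformers` `from_pretrained`.
        """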
device = infer_device(kwargs.pop("device", None))
preset_kwargs = cls.get_preset_kwargs(
pretrained_model_name_or_path,
preset,
device=device,
trust_remote_code=kwargs.get("trust_remote_code"),
            token=token,
)
preset_kwargs["low_cpu_mem_usage"] = True
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is None:
torch_dtype = infer_dtype(device)
if torch_dtype == torch.float16:
extra_message = ' and weights' if preset not in ['memory', 'memory_extreme'] else ''
logger.warning(
f"Defaulting to float16 precision for the computations{extra_message}. "
f"This can cause instabilities in generation for larger models, e.g. the 27b checkpoints."
)
preset_kwargs["torch_dtype"] = torch_dtype
preset_torch_compile = preset_kwargs.pop("torch_compile", False)
torch_compile = torch_compile if torch_compile is not None else preset_torch_compile
quantization_config = kwargs.pop("quantization_config", None)
if quantization_config is not None:
preset_kwargs["quantization_config"] = quantization_config
elif preset_kwargs.get("quantization_config"):
if device == "cuda":
preset_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_enable_fp32_cpu_offload=True,
bnb_4bit_compute_dtype=preset_kwargs["torch_dtype"],
)
else:
preset_kwargs["quantization_config"] = QuantoConfig(
weights=preset_kwargs["quantization_config"]["weights"]
)
        # kwargs passed by the user take precedence over the preset defaults
        for key in list(kwargs):
            if key in preset_kwargs:
                preset_kwargs[key] = kwargs.pop(key)
model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**preset_kwargs,
**kwargs,
)
if device not in str(model.device) and preset_kwargs.get("device_map", None) is None:
            # for consistent behaviour with bitsandbytes, always move the model to the target device
model.to(device, dtype=preset_kwargs["torch_dtype"])
        if torch_compile and device != "cuda":
            raise ValueError(
                "Torch compile is only compatible with cuda devices. Set `torch_compile=False` in `.from_pretrained` "
                f"for device {device}."
            )
elif torch_compile:
model = fuse_attention_weights(model, device, torch_dtype)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
return model
def fuse_attention_weights(model: LocalGemma2ForCausalLM, device, torch_dtype) -> LocalGemma2ForCausalLM:
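    """Swap each decoder layer's attention module for `Gemma2FusedAttention`,
    carrying the existing weights over via `load_state_dict`. Used on the
    `torch.compile` path in `from_pretrained` above."""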
for idx, layer in tqdm(enumerate(model.model.layers), desc="Fusing attention weights", total=model.config.num_hidden_layers):
state_dict = layer.self_attn.state_dict()
del layer.self_attn
layer.self_attn = Gemma2FusedAttention(model.config, layer_idx=idx)
        # convert the un-fused weights to fused via the load-state-dict pre-hook registered on Gemma2FusedAttention
layer.self_attn.load_state_dict(state_dict)
layer.self_attn.to(device, dtype=torch_dtype)
return model
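

if __name__ == "__main__":
    # Minimal smoke test (illustrative: assumes the checkpoint can be downloaded
    # or is cached locally, and that the backend required by the inferred preset
    # is installed).
    model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="auto")
    print(f"Loaded {model.config.num_hidden_layers}-layer model on {model.device}")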