# optimum/exporters/executorch/tasks/causal_lm.py

import logging

import torch
import torchao
from packaging.version import parse
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

# NOTE: The two relative import paths and the registration decorator below are
# assumed from the surrounding package layout (the docstring states the function
# is registered under the task 'text-generation').
from ..integrations import CausalLMExportableModule
from ..task_registry import register_task


@register_task("text-generation")
def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
"""
Loads a causal language model for text generation and registers it under the task
'text-generation' using Hugging Face's AutoModelForCausalLM.
Args:
model_name_or_path (str):
Model ID on huggingface.co or path on disk to the model repository to export. For example:
`model_name_or_path="meta-llama/Llama-3.2-1B"` or `mode_name_or_path="/path/to/model_folder`
**kwargs:
Additional configuration options for the model:
- dtype (str, optional):
Data type for model weights (default: "float32").
Options include "float16" and "bfloat16".
- attn_implementation (str, optional):
Attention mechanism implementation (default: "sdpa").
- cache_implementation (str, optional):
Cache management strategy (default: "static").
- max_length (int, optional):
Maximum sequence length for generation (default: 2048).
Returns:
CausalLMExportableModule:
An instance of `CausalLMExportableModule` for exporting and lowering to ExecuTorch.
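
    Example (illustrative; since the quantization kwargs are only checked for
    truthiness here, booleans suffice):

        module = load_causal_lm_model(
            "meta-llama/Llama-3.2-1B",
            dtype="bfloat16",
            max_length=1024,
            qlinear=True,
            qembedding=True,
        )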
"""
    device = "cpu"
    batch_size = 1
    dtype = kwargs.get("dtype", "float32")
    use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
    use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
    cache_implementation = kwargs.get("cache_implementation", "static")
    # Keep the flag in sync when "custom_sdpa" was requested via `attn_implementation` directly.
    use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
    max_length = kwargs.get("max_length", 2048)
    config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)
    if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
        # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite
        # that function to avoid the data-dependent control flow.
        config.rope_scaling["type"] = "default"
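
    # A generation config with a static cache of fixed max length is baked into the
    # model here, so that export sees fixed-shape KV-cache buffers.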
    eager_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device,
        torch_dtype=dtype,
        config=config,
        attn_implementation=attn_implementation,
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation=cache_implementation,
            max_length=max_length,
            cache_config={
                "batch_size": batch_size,
                "max_cache_len": max_length,
            },
        ),
    )
    for param in eager_model.parameters():
        # Must disable gradient for quantized checkpoint
        if isinstance(param, torchao.utils.TorchAOBaseTensor):
            param.requires_grad = False
    # TODO: Move quantization recipe out for better composability.
    # TODO: Should switch to `TorchAoConfig` once the quant issue on the final lm_head layer is fixed.
    qlinear_config = kwargs.get("qlinear", None)
    qembedding_config = kwargs.get("qembedding", None)
    if qlinear_config or qembedding_config:
        # TODO: Require the stable torchao 0.11.0 release once it is out.
        if parse(torchao.__version__) < parse("0.11.0.dev0"):
            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
        from torchao.quantization.granularity import PerAxis, PerGroup
        from torchao.quantization.quant_api import (
            Int8DynamicActivationIntxWeightConfig,
            IntxWeightOnlyConfig,
            quantize_,
        )
        from torchao.utils import unwrap_tensor_subclass
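
        # Two independent recipes: int8 per-axis weight-only quantization for the
        # embedding table, and 8-bit dynamic activations with 4-bit per-group
        # weights (8da4w) for linear layers.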
        if qembedding_config:
            logging.info("Quantizing embedding layers.")
            # TODO: Should switch to `AOPerModuleConfig` once a fix for tied weights is available.
            embedding_config = IntxWeightOnlyConfig(
                weight_dtype=torch.int8,
                granularity=PerAxis(0),
            )
            quantize_(
                eager_model,
                embedding_config,
                # The filter restricts quantization to `nn.Embedding` modules.
                lambda m, fqn: isinstance(m, torch.nn.Embedding),
            )
        if qlinear_config:
            logging.info("Quantizing linear layers.")
            linear_config = Int8DynamicActivationIntxWeightConfig(
                weight_dtype=torch.int4,
                weight_granularity=PerGroup(32),
            )
            quantize_(
                eager_model,
                linear_config,
            )
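
        # Unwrap the TorchAO tensor subclasses back into plain tensors (via
        # parametrization) so the quantized model remains traceable by torch.export.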
        unwrap_tensor_subclass(eager_model)
    return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)