# in optimum/habana/transformers/trainer.py
import inspect
from typing import Dict, Optional, Tuple

# `_is_peft_model` is assumed to come from `transformers.trainer`, where recent
# transformers releases define it.
from transformers.trainer import _is_peft_model


def _get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple[bool, Dict]:
    """
    Determines whether the input settings need to be updated.

    Currently (attn_softmax_bf16, use_flash_attention, flash_attention_recompute,
    flash_attention_causal_mask) are enabled only for llama, qwen2, starcoder2, gemma, baichuan,
    chatglm and deepseek_v2; lazy_mode is forwarded only for llama, qwen2, starcoder2 and mistral.

    Args:
        model: The model instance for which the input update settings are being evaluated
        lazy_mode (`Optional[bool]`, defaults to `None`): Whether to use lazy mode for the model

    Returns:
        Tuple[bool, Dict]: A flag indicating whether the input settings should be updated,
        and a dictionary containing the specific input settings that need to be updated, if any
    """
    inputs_update: Dict = {}

    # Flash-attention related flags are forwarded only for the model types below,
    # and only if they are enabled in the model's generation config.
    should_update_inputs = (getattr(model, "generation_config", None) is not None) and (
        model.config.model_type in ("llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm", "deepseek_v2")
    )
    if should_update_inputs:
        if model.generation_config.attn_softmax_bf16:
            inputs_update["attn_softmax_bf16"] = True
        if model.generation_config.use_flash_attention:
            inputs_update["use_flash_attention"] = True
        if model.generation_config.flash_attention_recompute:
            inputs_update["flash_attention_recompute"] = True
        if model.generation_config.flash_attention_causal_mask:
            inputs_update["flash_attention_causal_mask"] = True
    # lazy_mode is forwarded only for the model types below, and only if the model's
    # forward signature actually accepts a `lazy_mode` keyword.
    should_update_inputs = (
        (getattr(model, "generation_config", None) is not None)
        and (model.config.model_type in ("llama", "qwen2", "starcoder2", "mistral"))
        and (lazy_mode is not None)
    )
    if should_update_inputs:
        if _is_peft_model(model):
            forward_method = getattr(model.get_base_model(), "forward")
        else:
            forward_method = getattr(model, "forward")
        signature = inspect.signature(forward_method)
        if "lazy_mode" in signature.parameters:
            inputs_update["lazy_mode"] = lazy_mode
    should_update_inputs: bool = len(inputs_update) > 0

    return should_update_inputs, inputs_update
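

# A minimal usage sketch, not part of the source file: it assumes a Trainer-style step
# where a HuggingFace model and the batch `inputs` dict already exist, and that
# `use_lazy_mode` mirrors the trainer's lazy-mode argument. The helper name
# `_apply_input_updates` and both parameter names are illustrative.
def _apply_input_updates(model, inputs: Dict, use_lazy_mode: bool = True) -> Dict:
    should_update, updates = _get_input_update_settings(model, lazy_mode=use_lazy_mode)
    if should_update:
        # e.g. {"use_flash_attention": True, "lazy_mode": True} gets merged into the
        # keyword arguments passed to the model's forward call
        inputs.update(updates)
    return inputs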