def _get_input_update_settings()

in optimum/habana/transformers/trainer.py


import inspect
from typing import Dict, Optional, Tuple

# `_is_peft_model` is assumed to be a helper imported/defined earlier in trainer.py.


def _get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple[bool, Dict]:
    """
    Determines whether the input settings need to be updated.

    Currently (attn_softmax_bf16, use_flash_attention, flash_attention_recompute,
    flash_attention_causal_mask) are enabled only for llama, qwen2, starcoder2, gemma, baichuan
    and chatglm

    lazy_mode for llama, qwen2, starcoder2 and mistral

    Args:
        model: The model instance for which the input update settings are being evaluated
        lazy_mode[Optional[bool]]: Whether to use lazy mode for the model (defaults to `None`)

    Returns:
        Tuple[bool, Dict]: A flag indicating whether the input settings should be updated.
        A dictionary containing the specific input settings that need to be updated, if any
    """
    inputs_update: Dict = {}

    # Flash-attention / bf16-softmax flags are only forwarded for these model types
    should_update_inputs = (getattr(model, "generation_config", None) is not None) and (
        model.config.model_type in ("llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm", "deepseek_v2")
    )
    if should_update_inputs:
        if model.generation_config.attn_softmax_bf16:
            inputs_update["attn_softmax_bf16"] = True
        if model.generation_config.use_flash_attention:
            inputs_update["use_flash_attention"] = True
        if model.generation_config.flash_attention_recompute:
            inputs_update["flash_attention_recompute"] = True
        if model.generation_config.flash_attention_causal_mask:
            inputs_update["flash_attention_causal_mask"] = True

    # lazy_mode is only forwarded for these model types, and only if the (base) model's
    # forward signature actually accepts a `lazy_mode` argument
    should_update_inputs = (
        (getattr(model, "generation_config", None) is not None)
        and (model.config.model_type in ("llama", "qwen2", "starcoder2", "mistral"))
        and (lazy_mode is not None)
    )
    if should_update_inputs:
        if _is_peft_model(model):
            forward_method = getattr(model.get_base_model(), "forward")
        else:
            forward_method = getattr(model, "forward")
        signature = inspect.signature(forward_method)
        if "lazy_mode" in signature.parameters:
            inputs_update["lazy_mode"] = lazy_mode

    # The returned flag reflects whether any setting was actually collected above
    should_update_inputs: bool = len(inputs_update) > 0

    return should_update_inputs, inputs_update
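
Example usage (a minimal sketch; `inputs` stands for a hypothetical dict of model kwargs built by the caller, e.g. inside a training step):

    should_update_inputs, inputs_update = _get_input_update_settings(model, lazy_mode=True)
    if should_update_inputs:
        # Merge e.g. {"use_flash_attention": True, "lazy_mode": True} into the batch
        inputs.update(inputs_update)
    outputs = model(**inputs)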