def _convert_model()

in optimum/habana/accelerate/utils/transformer_engine.py [0:0]


def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
    """
    Recursively converts the layers of `model` to or from their `transformer_engine` counterparts, in place.

    Children are replaced with `setattr`, so `model` is mutated and nothing is returned.

    Args:
        model (`torch.nn.Module`):
            The module whose children are (recursively) converted.
        to_transformer_engine (`bool`, defaults to `True`):
            When `True`:
            - `torch.nn.Linear` children are replaced by `te.Linear`.
            - For PEFT LoRA `Linear` children, only the `base_layer` is replaced.
            - `ModuleFusedSDPA` children with `flash_attention_fp8` set are replaced by a wrapper around TE's
              `FusedAttention`.
            When `False`, `te.Linear` children are replaced by `torch.nn.Linear`.
        _convert_linear (`bool`, defaults to `True`):
            Whether linear layers are converted at all. The fused-SDPA conversion does not depend on this flag.

    Raises:
        ImportError: If transformer_engine (FP8 support) is not available.
    """
    from ...transformers.models.llama.modeling_llama import ModuleFusedSDPA

    if not is_fp8_available():
        raise ImportError("Using `convert_model` requires transformer_engine to be installed.")

    minimize_memory = str_to_bool(os.getenv("PT_HPU_FP8_MINIMIZE_MEMORY", "false"))
    # Evaluated once instead of per child. `lora` is only referenced when this is True.
    peft_available = is_peft_available()

    def _to_te_linear(linear):
        """Builds a `te.Linear` that shares (does not copy) the weight and bias Parameters of `linear`."""
        has_bias = linear.bias is not None
        # Initializing TE linear without weights and biases and shallow copying them from the original module.
        te_linear = te.Linear(
            linear.in_features,
            linear.out_features,
            bias=has_bias,
            params_dtype=linear.weight.dtype,
            skip_weight_param_allocation=True,
            minimize_memory=minimize_memory,
        )
        te_linear.weight = linear.weight
        if has_bias:
            te_linear.bias = linear.bias
        return te_linear

    def _to_torch_linear(te_linear):
        """Builds a `torch.nn.Linear` holding a copy of the weight and bias of `te_linear`."""
        has_bias = te_linear.bias is not None
        torch_linear = torch.nn.Linear(
            te_linear.in_features,
            te_linear.out_features,
            bias=has_bias,
            dtype=te_linear.weight.dtype,
            device=te_linear.weight.device,
        )
        # The new weight/bias are leaf Parameters with requires_grad=True. Writing into them in place raises
        # a RuntimeError unless autograd is disabled, so don't rely on the caller having entered no_grad.
        with torch.no_grad():
            torch_linear.weight.copy_(te_linear.weight)
            if has_bias:
                torch_linear.bias.copy_(te_linear.bias)
        return torch_linear

    def _to_te_fused_sdpa(sdpa_module):
        """Builds a module that runs TE's `FusedAttention` configured from `sdpa_module`'s settings."""
        from habana_frameworks.torch.hpex.experimental.transformer_engine import (
            FusedAttention as TE_FusedAttention,
        )

        class TE_ModuleFusedSDPA(torch.nn.Module):
            def __init__(self, source):
                super().__init__()
                self._hpu_kernel_fsdpa = TE_FusedAttention(
                    scale=source.scale,
                    attention_dropout=source.attention_dropout,
                    enable_recompute=source.enable_recompute,
                )

            def forward(
                self,
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                softmax_mode,
                recompute_mode,
                valid_sequence_lengths,
                padding_side="left",
            ):
                # NOTE(review): the argument list presumably mirrors ModuleFusedSDPA.forward so that call
                # sites need no change — confirm. The following arguments are accepted but not forwarded:
                # dropout_p, scale, recompute_mode, valid_sequence_lengths and padding_side. The kernel was
                # instead configured at construction from the source module's scale, attention_dropout and
                # enable_recompute.
                return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode)

        return TE_ModuleFusedSDPA(sdpa_module)

    for name, module in model.named_children():
        if peft_available and isinstance(module, lora.Linear) and to_transformer_engine and _convert_linear:
            # For a LoRA linear module, convert only the base linear layer to fp8 and skip the lora-a and
            # lora-b linear layers. They are small, so pushing them to fp8 gives little device performance
            # gain, and skipping them avoids the host overhead of running them through TE.
            for child_name, child in module.named_children():
                if child_name == "base_layer":
                    setattr(module, child_name, _to_te_linear(child))
        elif isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
            setattr(model, name, _to_te_linear(module))
        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
            setattr(model, name, _to_torch_linear(module))
        elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine:
            setattr(model, name, _to_te_fused_sdpa(module))
        else:
            _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)