def patch_stateful_decoder()

in optimum/exporters/openvino/stateful.py


def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
    """
    Apply stateful transformation to model to hide key values inputs inside model.
    Select transformation parameters based on model architecture

    Parameters:
        config (`PretrainedConfig`):
            model pretrained config
        ov_model (`ov.Model`):
            openvino model
    """

    # Collect the names of the past key/value cache inputs and the matching "present" outputs
    key_value_input_names = [
        key_name for key in ov_model.inputs for key_name in key.get_names() if "key_values" in key_name
    ]
    key_value_output_names = [
        key_name for key in ov_model.outputs for key_name in key.get_names() if "present" in key_name
    ]
    # All remaining inputs (e.g. input_ids, attention_mask) that are not part of the key/value cache
    not_kv_inputs = [
        input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
    ]
    # Nothing to do for models exported without a key/value cache
    if not key_value_input_names or not key_value_output_names:
        return

    # By default batch is the 0-th dimension, but chatglm uses the 1-st dimension as batch
    # (chatglm variants that define `rope_ratio` follow the default layout)
    # TODO: Deduce from a model via ordinal reshape (?) and topology
    batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0

    # Add a beam_idx input and gather ops so the key/value cache can be reordered during beam search
    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
    # Bloom exported with transformers < 4.44 fuses batch and head dimensions in its cache shape,
    # so the head count is needed to recover the batch size; other models use 1
    num_attention_heads = (
        config.num_attention_heads if (config.model_type == "bloom" and is_transformers_version("<", "4.44")) else 1
    )
    # Hide the key/value inputs and outputs inside the model as internal state
    make_stateful(
        ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
    )
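
A minimal usage sketch is shown below. It assumes the function is importable from optimum.exporters.openvino.stateful (as the file path above suggests) and that a decoder model has already been converted to an ov.Model with explicit past_key_values / present inputs and outputs; the model id "gpt2" and the path "model.xml" are illustrative placeholders, not part of stateful.py.

import openvino as ov
from transformers import AutoConfig

from optimum.exporters.openvino.stateful import patch_stateful_decoder

# Illustrative config and exported model; substitute your own export
config = AutoConfig.from_pretrained("gpt2")
ov_model = ov.Core().read_model("model.xml")  # decoder exported with past_key_values inputs

# Rewrites the model in place: the key/value tensors become internal state,
# so subsequent inference calls no longer pass the cache explicitly
patch_stateful_decoder(config, ov_model)

Note that the function returns early, leaving the model unchanged, when no "key_values" inputs or "present" outputs are found, so calling it on a model exported without a cache is a no-op.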