optimum/exporters/openvino/stateful.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging as log
from typing import List
import numpy as np
from transformers import PretrainedConfig
import openvino as ov
from openvino import opset13
from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version, is_transformers_version
from .utils import MULTI_MODAL_TEXT_GENERATION_MODELS
def model_has_state(ov_model: ov.Model):
if isinstance(ov_model, ov.CompiledModel):
return len(ov_model.query_state()) > 0
# TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods
return len(ov_model.get_sinks()) > 0
def model_has_input_output_name(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has a specified input or output name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input or output
Returns:
True if input or output with requested name exists else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])
def fuse_cache_reorder(
ov_model: ov.Model,
not_kv_inputs: List[str],
key_value_input_names: List[str],
gather_dim: int,
):
"""
Fuses reorder_cache during the generate cycle into ov.Model. Used with stateful models, because we cannot modify the model state directly.
Adds a new beam_idx parameter and a Gather op for each kv-cache input in a given model.
Should be run before make_stateful. Implements optimum's _reorder_cache
inside the model at the beginning of each iteration.
Gather works along the given gather_dim dimension that may vary from model to model.
KV-cache inputs are identified based on names in key_value_input_names.
Appends the new beam_idx parameter to not_kv_inputs.
Parameters:
ov_model (`ov.Model`):
openvino model for processing
not_kv_inputs (`List[str]`):
list of input nodes in the model that are not related to past key values
key_value_input_names (`List[str]`):
list of names for key value input layers
gather_dim (int):
dimension for gathering cache during reorder pass
"""
if model_has_input_output_name(ov_model, "beam_idx"):
raise ValueError("Model already has fused cache")
main_input_name = "input_ids" if model_has_input_output_name(ov_model, "input_ids") else "inputs_embeds"
input_batch = ov_model.input(main_input_name).get_partial_shape()[0]
beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted?
ov_model.add_parameters([beam_idx])
not_kv_inputs.append(ov_model.inputs[-1])
# Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
for input_name in key_value_input_names:
parameter_output_port = ov_model.input(input_name)
consumers = parameter_output_port.get_target_inputs()
gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
for consumer in consumers:
consumer.replace_source_output(gather.output(0))
ov_model.validate_nodes_and_infer_types()
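# Illustrative sketch (hedged, not part of this module's API): once fuse_cache_reorder has run,
# beam reordering can be driven through the new "beam_idx" input instead of shuffling cache
# tensors in Python. `compiled_model`, `inputs` and `beam_indices` below are hypothetical placeholders.
#
#     inputs["beam_idx"] = np.asarray(beam_indices, dtype=np.int32)
#     results = compiled_model(inputs)
#
# The Gather inserted above then selects past key values along `gather_dim` for every kv-cache
# input at the start of each generation step, mirroring optimum's _reorder_cache.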
def build_state_initializer(ov_model: ov.Model, batch_dim: int):
"""
Build initialization ShapeOf Expression for all ReadValue ops
Parameters:
ov_model (ov.Model):
openvino model
batch_dim (int):
index of dimension corresponding to batch size
"""
main_input_name = "input_ids" if model_has_input_output_name(ov_model, "input_ids") else "inputs_embeds"
input_ids = ov_model.input(main_input_name)
batch = opset13.gather(opset13.shape_of(input_ids, output_type="i64"), opset13.constant([0]), opset13.constant(0))
for op in ov_model.get_ops():
if op.get_type_name() == "ReadValue":
dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
dims[batch_dim] = batch
dims = [opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim for dim in dims]
shape = opset13.concat(dims, axis=0)
broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
op.set_arguments([broadcast])
ov_model.validate_nodes_and_infer_types()
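# For a typical kv-cache input with a partial shape like [-1, num_heads, -1, head_dim] (an assumed
# layout, not derived from this module), the initializer built above behaves roughly like
#
#     np.zeros([batch_size, num_heads, 0, head_dim], dtype=cache_dtype)
#
# where batch_size is read from ShapeOf(input_ids) at runtime and every other dynamic dimension
# falls back to its minimum length (0), so each ReadValue starts from an empty cache with the
# correct batch size.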
def make_stateful(
ov_model: ov.Model,
not_kv_inputs: List[str],
key_value_input_names: List[str],
key_value_output_names: List[str],
batch_dim: int,
num_attention_heads: int,
num_beams_and_batch: int = None,
):
"""
Hides kv-cache inputs and outputs inside the model as variables.
Parameters:
ov_model (ov.Model):
openvino model
not_kv_inputs (`List[str]`):
list of input nodes in the model that are not related to past key values
key_value_input_names (`List[str]`):
list of names for key value input layers
key_value_output_names (`List[str]`):
list of names for key value output layers
batch_dim (int):
index of batch dimension in key value layers
num_attention_heads (int):
number of attention heads for batch dimension initialization
num_beams_and_batch (int):
precalculated number of beams and batch for shapes initialization
"""
from openvino._offline_transformations import apply_make_stateful_transformation
input_output_map = {}
# TODO: Can we derive the dimensions from the model topology?
if num_beams_and_batch is not None:
# Set batch size for input_ids and attention mask to avoid a dynamic dimension being propagated from the end of the model back to ReadValue
for input in not_kv_inputs:
shape = input.get_partial_shape()
if shape.rank.get_length() <= 2: # == 1 for beam_idx
shape[0] = num_beams_and_batch
input.get_node().set_partial_shape(shape)
else:
log.warning(f"Rank of {input.get_any_name()} input of the model is not 2, batch size is not set")
for kv_name_pair in zip(key_value_input_names, key_value_output_names):
input_output_map[kv_name_pair[0]] = kv_name_pair[1]
if num_beams_and_batch is not None:
input = ov_model.input(kv_name_pair[0])
shape = input.get_partial_shape()
shape[batch_dim] = num_beams_and_batch * num_attention_heads
input.get_node().set_partial_shape(shape)
if num_beams_and_batch is not None:
# Re-validate the model if shapes were altered above
ov_model.validate_nodes_and_infer_types()
apply_make_stateful_transformation(ov_model, input_output_map)
if num_beams_and_batch is None:
build_state_initializer(ov_model, batch_dim)
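# Minimal illustration (names follow the usual optimum export conventions, assumed here): for one
# decoder layer the map handed to apply_make_stateful_transformation looks like
#
#     {"past_key_values.0.key": "present.0.key",
#      "past_key_values.0.value": "present.0.value"}
#
# and each input/output pair is replaced by a ReadValue/Assign variable kept in the model state
# between inference calls.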
def ensure_stateful_is_available(warn=True):
"""
Check the OpenVINO version and warn if it does not support stateful models; returns False in that case
"""
if is_openvino_version("<", "2023.3"):
if warn:
log.warning(
f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. It may result in sub-optimal inference performance."
"Install openvino>=2023.3.0."
)
return False
return True
_ENCODER_DECODER_TASKS_WITH_PAST = (
"automatic-speech-recognition",
"text2text-generation",
)
_DECODER_TASKS_WITH_PAST = ("text-generation",)
def ensure_export_task_support_stateful(task: str):
from optimum.exporters import TasksManager
task = TasksManager.map_from_synonym(task)
is_stateful = (
task.endswith("-with-past")
and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + _DECODER_TASKS_WITH_PAST
)
return is_stateful
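# For reference, with the task sets above:
#     ensure_export_task_support_stateful("text-generation-with-past")       # True
#     ensure_export_task_support_stateful("text-generation")                 # False, no -with-past suffix
#     ensure_export_task_support_stateful("text2text-generation-with-past")  # True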
def ensure_model_type_support_stateful(model_type: str):
return model_type.replace("_", "-") in MULTI_MODAL_TEXT_GENERATION_MODELS
def remove_parameters_by_names(model: ov.Model, names: list):
parameters = [model.input(name).get_node() for name in names]
for p in parameters:
model.remove_parameter(p)
def get_input_nodes(node):
return [input.get_node() for input in node.input_values()]
def find_dependent_nodes(model: ov.Model, sources: list):
# Finds all nodes in `model` that are directly or indirectly dependent on at least one node from the list of nodes in `sources`, including `sources`
result = set(sources)
for node in model.get_ordered_ops():
input_nodes = set(get_input_nodes(node))
if input_nodes & result:
result.add(node)
return result
def get_read_value_ops(model: ov.Model):
return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"]
def get_shape_of_ops(model: ov.Model):
return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"]
def get_consumer_nodes(node):
consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()])
return {input.get_node() for input in consumer_inputs}
def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
# Search for nodes in the model graph that depend on nodes in the `sources` list but are independent of other model Parameters/ReadValues
other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources)
other_nodes = find_dependent_nodes(model, other_inputs)
source_dependent_nodes = find_dependent_nodes(model, sources)
# TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph
nodes = source_dependent_nodes - other_nodes
edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes]
return edge_nodes
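# In the encoder-decoder patching below, `sources` is the "encoder_hidden_states" Parameter; the
# edge nodes found here are typically the cross-attention key/value projections, which depend only
# on the encoder output and can therefore be computed once and served from state afterwards.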
def insert_state_for_nodes(model: ov.Model, nodes):
# For each output in a given list `nodes` of ov.Node's, insert a ReadValue-Assign pair and use the node output as the initialization sub-expression
outputs = sum((node.outputs() for node in nodes), [])
for output in outputs:
consumers = output.get_target_inputs()
# FIXME: get_any_name is not reliable as tensor may not have any names
variable_id = output.get_any_name()
read_value = ov.opset13.read_value(output, variable_id)
for consumer in consumers:
consumer.replace_source_output(read_value.output(0))
assign = ov.opset13.assign(read_value, variable_id)
model.add_sinks([assign])
def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
return patch_stateful_encoder_decoder(config, ov_model)
return patch_stateful_decoder(config, ov_model)
def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
"""
Apply stateful transformation to the model to hide key/value inputs inside the model.
Selects transformation parameters based on the model architecture
Parameters:
config (`PretrainedConfig`):
model pretrained config
ov_model (`ov.Model`):
openvino model
"""
key_value_input_names = [
key_name for key in ov_model.inputs for key_name in key.get_names() if "key_values" in key_name
]
key_value_output_names = [
key_name for key in ov_model.outputs for key_name in key.get_names() if "present" in key_name
]
not_kv_inputs = [
input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
]
if not key_value_input_names or not key_value_output_names:
return
# By default, batch is the 0-th but chatglm uses 1-st dimension as batch
# TODO: Deduce from a model via ordinal reshape (?) and topology
batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
num_attention_heads = (
config.num_attention_heads if (config.model_type == "bloom" and is_transformers_version("<", "4.44")) else 1
)
make_stateful(
ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
)
def patch_stateful_encoder_decoder(config, ov_model):
encoder_key_value_input_names = [
key.get_any_name()
for key in ov_model.inputs
if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
]
remove_parameters_by_names(ov_model, encoder_key_value_input_names)
patch_stateful_decoder(config, ov_model)
insert_state_for_nodes(
ov_model,
find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
)
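# Hedged end-to-end sketch (the model id and file names below are placeholders, not part of this module):
#
#     import openvino as ov
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("some-org/some-decoder")      # hypothetical model id
#     ov_model = ov.Core().read_model("decoder_with_past.xml")          # hypothetical exported IR
#     patch_stateful(config, ov_model)   # kv-cache inputs/outputs become internal state
#     ov.save_model(ov_model, "decoder_stateful.xml")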