backends/python/server/text_embeddings_server/models/flash_bert.py
import torch
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from typing import Type, List, Union
from safetensors import safe_open
from transformers.activations import ACT2FN
from transformers.models.bert import BertConfig
from opentelemetry import trace
from text_embeddings_server.models import Model
from text_embeddings_server.models.types import FlashBatch, Embedding, PaddedBatch
from text_embeddings_server.utils.flash_attn import attention
from text_embeddings_server.utils.device import use_ipex
tracer = trace.get_tracer(__name__)
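
# Eager residual-add + layer-norm used on Habana (HPU) devices, where the fused
# CUDA kernel is unavailable; when add_back is set, the pre-norm sum is written
# back into the residual tensor in place.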
def hpu_add_layer_norm(
add: torch.Tensor,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
add_back: bool,
):
if add is not None:
added_tensor = torch.add(add, x, alpha=1.0)
output = F.layer_norm(added_tensor, [x.size(-1)], weight, bias, epsilon)
if add_back:
add.add_(x)
return output
else:
return F.layer_norm(x, [x.size(-1)], weight=weight, bias=bias, eps=epsilon)
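
# Layer norm fused with a residual add, dispatched per device: flash-attn's
# dropout_add_ln kernel on CUDA, ipex.llm.functional.add_layer_norm when IPEX
# is enabled, and the eager helper above on HPU.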
class FastLayerNorm:
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
self.variance_epsilon = config.layer_norm_eps
self.device = device
self.use_ipex = use_ipex()
def forward(self, hidden_states, residual=None):
        normed_hidden_states = None
        res = None
        if self.device.type == "cuda":
            # dropout_layer_norm ships with the flash-attention kernels; import
            # it lazily so non-CUDA devices do not need it installed.
            import dropout_layer_norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
False,
)
if res is None:
res = hidden_states
elif self.use_ipex:
import intel_extension_for_pytorch as ipex
normed_hidden_states = ipex.llm.functional.add_layer_norm(
residual,
hidden_states,
self.weight,
self.bias,
self.variance_epsilon,
residual is not None,
)
res = residual if residual is not None else hidden_states
elif self.device.type == "hpu":
normed_hidden_states = hpu_add_layer_norm(
residual,
hidden_states,
self.weight,
self.bias,
self.variance_epsilon,
residual is not None,
)
res = residual if residual is not None else hidden_states
return normed_hidden_states, res
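
# Token embeddings read straight from the safetensors handle; only absolute
# position embeddings are supported by this flash implementation.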
class BertEmbeddings:
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.word_embeddings_weight = (
handle.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
)
self.token_type_embeddings_weight = (
handle.get_tensor(f"{prefix}.token_type_embeddings.weight")
.to(dtype)
.to(device)
)
if config.position_embedding_type == "absolute":
self.position_embeddings_weight = (
handle.get_tensor(f"{prefix}.position_embeddings.weight")
.to(dtype)
.to(device)
)
else:
raise NotImplementedError(
"FlashBert only supports absolute position embeddings"
)
self.layer_norm = FastLayerNorm(
f"{prefix}.LayerNorm", handle, device, dtype, config
)
def forward(self, input_ids, token_type_ids, position_ids):
inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight)
token_type_embeds = nn.functional.embedding(
token_type_ids, self.token_type_embeddings_weight
)
position_embeds = nn.functional.embedding(
position_ids, self.position_embeddings_weight
)
inputs_embeds += position_embeds
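        # Token type embeddings are folded in through the fused add + layer-norm
        # residual path rather than added explicitly here.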
embeddings, _ = self.layer_norm.forward(inputs_embeds, token_type_embeds)
return embeddings
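
# Self-attention block: Q, K and V are computed with a single fused projection
# and the output projection feeds a fused add + layer-norm.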
class BertAttention:
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
query_weight = handle.get_tensor(f"{prefix}.self.query.weight")
query_bias = handle.get_tensor(f"{prefix}.self.query.bias")
key_weight = handle.get_tensor(f"{prefix}.self.key.weight")
key_bias = handle.get_tensor(f"{prefix}.self.key.bias")
value_weight = handle.get_tensor(f"{prefix}.self.value.weight")
value_bias = handle.get_tensor(f"{prefix}.self.value.bias")
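        # Concatenate the Q, K and V projections so they run as a single matmul.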
self.qkv_weight = (
torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device)
)
self.qkv_bias = (
torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device)
)
self.dense_weight = (
handle.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
)
self.dense_bias = (
handle.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
)
self.layer_norm = FastLayerNorm(
f"{prefix}.output.LayerNorm", handle, device, dtype, config
)
self.head_size = config.hidden_size // config.num_attention_heads
self.softmax_scale = self.head_size**-0.5
self.num_heads = config.num_attention_heads
self.device = device
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
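        # The input is either flat ([total_tokens, hidden], varlen/flash path) or
        # padded 3-D ([batch, seq, hidden], HPU path); the fused QKV output is
        # reshaped accordingly before calling the attention kernel.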
residual = hidden_states
qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias)
bs = 1
hidden_dim = hidden_states.size(-1)
is_flat = True
if hidden_states.dim() > 2:
is_flat = False
bs = hidden_states.size(0)
q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split(
self.num_heads, dim=2
)
else:
q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
self.num_heads, dim=1
)
attn_output = torch.empty_like(q)
attention(
q,
k,
v,
attn_output,
cu_seqlens,
max_s,
self.softmax_scale,
attn_mask=attn_mask,
)
hidden_states = torch.addmm(
self.dense_bias,
attn_output.view(-1, self.num_heads * self.head_size),
self.dense_weight,
)
if not is_flat:
hidden_states = hidden_states.view(bs, -1, hidden_dim)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
return hidden_states
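
# One encoder layer: attention block followed by the intermediate/output MLP,
# each ending in a fused add + layer-norm.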
class BertLayer:
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.attention = BertAttention(
f"{prefix}.attention", handle, device, dtype, config
)
self.intermediate_weight = (
handle.get_tensor(f"{prefix}.intermediate.dense.weight")
.T.to(dtype)
.to(device)
)
self.intermediate_bias = (
handle.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)
)
act = config.hidden_act
self.intermediate_act_fn = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
)
self.output_weight = (
handle.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
)
self.output_bias = (
handle.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
)
self.layer_norm = FastLayerNorm(
f"{prefix}.output.LayerNorm", handle, device, dtype, config
)
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
hidden_states = self.attention.forward(
hidden_states, cu_seqlens, max_s, attn_mask
)
residual = hidden_states
hidden_states = F.linear(
hidden_states, self.intermediate_weight.T, self.intermediate_bias
)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = F.linear(hidden_states, self.output_weight.T, self.output_bias)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
return hidden_states
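
# Stack of BertLayer modules applied sequentially.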
class BertEncoder:
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.layers = [
BertLayer(f"{prefix}.layer.{i}", handle, device, dtype, config)
for i in range(config.num_hidden_layers)
]
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
for layer in self.layers:
hidden_states = layer.forward(hidden_states, cu_seqlens, max_s, attn_mask)
return hidden_states
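
# Full BERT encoder (embeddings + transformer layers); forward returns the
# hidden state of the first (CLS) token of every sequence.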
class FlashBertModel:
def __init__(self, handle, device, dtype, config: BertConfig):
self.embeddings = BertEmbeddings("embeddings", handle, device, dtype, config)
self.encoder = BertEncoder("encoder", handle, device, dtype, config)
def forward(
self,
input_ids,
token_type_ids,
position_ids,
cu_seqlens,
max_s,
mask=None,
attn_mask=None,
):
embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s, attn_mask)
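        # cu_seqlens[:-1] holds the start offset of each sequence, i.e. its first
        # (CLS) token; on the padded path the outputs are flattened with the
        # boolean mask first so the same indexing applies.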
if mask is not None:
outputs = encoder_outputs[mask]
return outputs[cu_seqlens[:-1]]
return encoder_outputs[cu_seqlens[:-1]]
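
# Serving wrapper around FlashBertModel: loads weights from model.safetensors,
# selects the batch layout for the device and produces one Embedding per input.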
class FlashBert(Model):
def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
config = BertConfig.from_pretrained(model_path)
if hasattr(config, "max_seq_length"):
self.max_input_length = config.max_seq_length
else:
self.max_input_length = config.max_position_embeddings
with safe_open(model_path / "model.safetensors", framework="pt") as f:
model = FlashBertModel(f, device, dtype, config)
self.device = device
self.dtype = dtype
self.hidden_size = config.hidden_size
super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
@property
    def batch_type(self) -> Type[Union[FlashBatch, PaddedBatch]]:
        # HPU devices use PaddedBatch because a real varlen forward is not available there yet.
return FlashBatch if self.device.type != "hpu" else PaddedBatch
@tracer.start_as_current_span("embed")
def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
if isinstance(batch, PaddedBatch):
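            # Padded (HPU) path: recover per-sequence lengths from the attention
            # mask and build an additive bias that is 0 on real tokens and the
            # dtype minimum on padding positions.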
input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
            max_input_lens = 0  # unused on the padded path; lengths are carried by attn_mask
cu_seqlens = torch.cat(
(input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
)
mask = batch.attention_mask.bool()
bsz, tgt_len = mask.size()
min_val = torch.finfo(self.dtype).min
attn_mask = torch.full(
[bsz, 1, tgt_len, tgt_len],
fill_value=min_val,
device=self.device,
dtype=self.dtype,
)
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
elif isinstance(batch, FlashBatch):
cu_seqlens = batch.cu_seqlens
mask = None
attn_mask = None
max_input_lens = batch.max_s
embedding = self.model.forward(
input_ids=batch.input_ids,
token_type_ids=batch.token_type_ids,
position_ids=batch.position_ids,
cu_seqlens=cu_seqlens,
max_s=max_input_lens,
mask=mask,
attn_mask=attn_mask,
)
cpu_results = embedding.view(-1).tolist()
return [
Embedding(
values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
)
for i in range(len(batch))
]
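

# Illustrative usage sketch (hedged): the checkpoint path and device below are
# placeholders, and in practice the serving code constructs this class and
# builds FlashBatch / PaddedBatch objects from incoming requests.
#
#   model = FlashBert(
#       model_path=Path("/data/bert-base-uncased"),
#       device=torch.device("cuda"),
#       dtype=torch.float16,
#   )
#   batch = ...  # a FlashBatch (CUDA/CPU) or PaddedBatch (HPU) built by the server
#   embeddings = model.embed(batch)  # one Embedding of length hidden_size per sequence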