backends/python/server/text_embeddings_server/models/jinaBert_model.py

import torch
import math
from torch import nn
import torch.nn.functional as F
from pathlib import Path
from typing import Type, List, Optional, Union, Tuple
from transformers import AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from opentelemetry import trace
from safetensors import safe_open

from text_embeddings_server.models.pooling import DefaultPooling
from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


class JinaBertConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        feed_forward_type="original",
        emb_pooler=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.feed_forward_type = feed_forward_type
        self.emb_pooler = emb_pooler


class JinaBertEmbeddings:
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, handle, device, dtype, config: JinaBertConfig):
        self.word_embeddings_weight = (
            handle.get_tensor(f"embeddings.word_embeddings.weight")
            .to(dtype)
            .to(device)
        )
        self.token_type_embeddings_weight = (
            handle.get_tensor(f"embeddings.token_type_embeddings.weight")
            .to(dtype)
            .to(device)
        )
        self.layernorm_weight = (
            handle.get_tensor(f"embeddings.LayerNorm.weight").to(dtype).to(device)
        )
        self.layernorm_bias = (
            handle.get_tensor(f"embeddings.LayerNorm.bias").to(dtype).to(device)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        self.config = config

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        inputs_embeds = F.embedding(input_ids, self.word_embeddings_weight)
        token_type_embeddings = F.embedding(
            token_type_ids, self.token_type_embeddings_weight
        )
        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = F.layer_norm(
            embeddings,
            self.layernorm_weight.shape,
            self.layernorm_weight,
            self.layernorm_bias,
            eps=self.config.layer_norm_eps,
        )
        embeddings = self.dropout(embeddings)
        return embeddings


class JinaBertSelfAttention:
    def __init__(self, prefix, handle, device, dtype, config: JinaBertConfig):
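        # Weights are read eagerly from the safetensors handle; the query and key
        # projections each have their own LayerNorm (layer_norm_q / layer_norm_k),
        # which is applied in forward() before the attention scores are computed.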
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query_weight = (
            handle.get_tensor(f"{prefix}.query.weight").to(dtype).to(device)
        )
        self.query_bias = handle.get_tensor(f"{prefix}.query.bias").to(dtype).to(device)
        self.key_weight = handle.get_tensor(f"{prefix}.key.weight").to(dtype).to(device)
        self.key_bias = handle.get_tensor(f"{prefix}.key.bias").to(dtype).to(device)
        self.value_weight = (
            handle.get_tensor(f"{prefix}.value.weight").to(dtype).to(device)
        )
        self.value_bias = handle.get_tensor(f"{prefix}.value.bias").to(dtype).to(device)
        self.layer_norm_q_weight = (
            handle.get_tensor(f"{prefix}.layer_norm_q.weight").to(dtype).to(device)
        )
        self.layer_norm_q_bias = (
            handle.get_tensor(f"{prefix}.layer_norm_q.bias").to(dtype).to(device)
        )
        self.layer_norm_k_weight = (
            handle.get_tensor(f"{prefix}.layer_norm_k.weight").to(dtype).to(device)
        )
        self.layer_norm_k_bias = (
            handle.get_tensor(f"{prefix}.layer_norm_k.bias").to(dtype).to(device)
        )
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        bias: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.Tensor]:
        q_hidden_states = F.linear(hidden_states, self.query_weight, self.query_bias)
        mixed_query_layer = F.layer_norm(
            q_hidden_states,
            self.layer_norm_q_weight.shape,
            self.layer_norm_q_weight,
            self.layer_norm_q_bias,
            eps=self.config.layer_norm_eps,
        )
        k_hidden_states = F.linear(hidden_states, self.key_weight, self.key_bias)
        key_layer = self.transpose_for_scores(
            F.layer_norm(
                k_hidden_states,
                self.layer_norm_k_weight.shape,
                self.layer_norm_k_weight,
                self.layer_norm_k_bias,
                eps=self.config.layer_norm_eps,
            )
        )
        v_hidden_states = F.linear(hidden_states, self.value_weight, self.value_bias)
        value_layer = self.transpose_for_scores(v_hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the model's forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
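        # `bias` is the per-head ALiBi tensor built in JinaBertEncoder.forward; it is
        # added to the scaled scores just before the softmax.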
        attention_probs = F.softmax(attention_scores + bias, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        outputs = (context_layer,)
        return outputs


class JinaBertSelfOutput:
    def __init__(self, prefix, handle, device, dtype, config):
        self.config = config
        self.dense_weight = (
            handle.get_tensor(f"{prefix}.dense.weight").to(dtype).to(device)
        )
        self.dense_bias = handle.get_tensor(f"{prefix}.dense.bias").to(dtype).to(device)
        self.layerNorm_weight = (
            handle.get_tensor(f"{prefix}.LayerNorm.weight").to(dtype).to(device)
        )
        self.layerNorm_bias = (
            handle.get_tensor(f"{prefix}.LayerNorm.bias").to(dtype).to(device)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias)
        hidden_states = self.dropout(hidden_states)
        hidden_states = F.layer_norm(
            hidden_states + input_tensor,
            self.layerNorm_weight.shape,
            self.layerNorm_weight,
            self.layerNorm_bias,
            eps=self.config.layer_norm_eps,
        )
        return hidden_states


class JinaBertAttention:
    def __init__(self, prefix, handle, device, dtype, config):
        self.self = JinaBertSelfAttention(
            f"{prefix}.self", handle, device, dtype, config
        )
        self.output = JinaBertSelfOutput(
            f"{prefix}.output", handle, device, dtype, config
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        bias: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self.forward(
            hidden_states,
            attention_mask,
            bias,
        )
        attention_output = self.output.forward(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


class JinaBertGLUMLP:
    def __init__(self, prefix, handle, device, dtype, config: JinaBertConfig):
        self.config = config
        if config.feed_forward_type == "reglu":
            self.act = nn.ReLU()
        elif config.feed_forward_type == "geglu":
            self.act = nn.GELU()
        else:
            raise ValueError(
                f"feed_forward_type {config.feed_forward_type} not supported"
            )
        self.up_gated_layer_weight = (
            handle.get_tensor(f"{prefix}.up_gated_layer.weight").to(dtype).to(device)
        )
        self.down_layer_weight = (
            handle.get_tensor(f"{prefix}.down_layer.weight").to(dtype).to(device)
        )
        self.down_layer_bias = (
            handle.get_tensor(f"{prefix}.down_layer.bias").to(dtype).to(device)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Up with gate
        hidden_mlp_states = F.linear(hidden_states, self.up_gated_layer_weight, None)
        up = hidden_mlp_states[:, :, : self.config.intermediate_size]
        gated = hidden_mlp_states[:, :, self.config.intermediate_size :]
        hidden_mlp_states = up * self.act(gated)
        hidden_mlp_states = self.dropout(hidden_mlp_states)
        # Down
        return F.linear(
            hidden_mlp_states, self.down_layer_weight, self.down_layer_bias
        )


class JinaBertLayer:
    def __init__(self, prefix, handle, device, dtype, config: JinaBertConfig):
        self.attention = JinaBertAttention(
            f"{prefix}.attention", handle, device, dtype, config
        )
        self.config = config
        self.feed_forward_type = config.feed_forward_type
        self.layer_norm_1_weight = (
            handle.get_tensor(f"{prefix}.layer_norm_1.weight").to(dtype).to(device)
        )
        self.layer_norm_1_bias = (
            handle.get_tensor(f"{prefix}.layer_norm_1.bias").to(dtype).to(device)
        )
        self.layer_norm_2_weight = (
            handle.get_tensor(f"{prefix}.layer_norm_2.weight").to(dtype).to(device)
        )
        self.layer_norm_2_bias = (
            handle.get_tensor(f"{prefix}.layer_norm_2.bias").to(dtype).to(device)
        )
        if self.feed_forward_type.endswith("glu"):
            self.mlp = JinaBertGLUMLP(f"{prefix}.mlp", handle, device, dtype, config)
        else:
            raise ValueError(
                f"feed_forward_type {self.feed_forward_type} not supported"
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        bias: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.Tensor]:
        # Pre-Norm
        residual = hidden_states
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attention_outputs = self.attention.forward(
            hidden_states,
            attention_mask,
            bias=bias,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights
        residual = F.layer_norm(
            residual + attention_output,
            self.layer_norm_1_weight.shape,
            self.layer_norm_1_weight,
            self.layer_norm_1_bias,
            eps=self.config.layer_norm_eps,
        )
        mlp_output = self.mlp.forward(residual)
        layer_output = F.layer_norm(
            residual + mlp_output,
            self.layer_norm_2_weight.shape,
            self.layer_norm_2_weight,
            self.layer_norm_2_bias,
            eps=self.config.layer_norm_eps,
        )
        outputs = (layer_output,) + outputs
        return outputs


class JinaBertEncoder:
    def __init__(self, handle, device, dtype, config: JinaBertConfig):
        self.config = config
        self.layers = [
            JinaBertLayer(f"encoder.layer.{i}", handle, device, dtype, config)
            for i in range(config.num_hidden_layers)
        ]
        self.num_attention_heads = config.num_attention_heads

    def rebuild_alibi_tensor(
        self, size: int, device: Optional[Union[torch.device, str]] = None
    ):
        # Alibi
        # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
        # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
        # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
        # will be applied, it is necessary to construct the diagonal mask.
        n_heads = self.num_attention_heads

        def _get_alibi_head_slopes(n_heads: int) -> List[float]:
            def get_slopes_power_of_2(n):
                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            # In the paper, we only train models that have 2^a heads for some a. This
            # function has some good properties that only occur when the input is a
            # power of 2. To maintain that even when the number of heads is not a
            # power of 2, we use this workaround.
            if math.log2(n_heads).is_integer():
                return get_slopes_power_of_2(n_heads)
            else:
                closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
                return (
                    get_slopes_power_of_2(closest_power_of_2)
                    + _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
                        : n_heads - closest_power_of_2
                    ]
                )

        context_position = torch.arange(size, device=device)[:, None]
        memory_position = torch.arange(size, device=device)[None, :]
        relative_position = torch.abs(memory_position - context_position)
        # [n_heads, max_token_length, max_token_length]
        relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1)
        slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device) * -1
        alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
        # [1, n_heads, max_token_length, max_token_length]
        alibi = alibi.unsqueeze(0)
        assert alibi.shape == torch.Size([1, n_heads, size, size])

        self._current_alibi_size = size
        return alibi

    def forward(
        self,
        hidden_states: torch.Tensor,
        max_len: int,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        # Add alibi matrix to extended_attention_mask
        bs, seqlen, _ = hidden_states.size()
        alibi_bias = self.rebuild_alibi_tensor(
            size=max_len, device=hidden_states.device
        ).to(hidden_states.dtype)
        full_alibi_bias = torch.full(
            (bs, self.num_attention_heads, seqlen, seqlen),
            fill_value=torch.finfo(hidden_states.dtype).min,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        full_alibi_bias[:, :, :max_len, :max_len] = alibi_bias

        for i, layer_module in enumerate(self.layers):
            layer_outputs = layer_module.forward(
                hidden_states,
                attention_mask,
                full_alibi_bias,
            )
            hidden_states = layer_outputs[0]

        return hidden_states


class FlashJinaBertModel:
    def __init__(self, handle, device, dtype, config: AutoConfig):
        self.embeddings = JinaBertEmbeddings(handle, device, dtype, config)
        self.encoder = JinaBertEncoder(handle, device, dtype, config)

    def forward(
        self,
        input_ids,
        token_type_ids,
        position_ids,
        max_len,
        attn_mask=None,
    ):
        embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
        encoder_outputs = self.encoder.forward(embeddings, max_len, attn_mask)
        return encoder_outputs


class FlashJinaBert(Model):
    def __init__(
        self,
        model_path: Path,
        device: torch.device,
        dtype: torch.dtype,
        pool: str = "mean",
        trust_remote: bool = True,
    ):
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote)
        if hasattr(config, "max_seq_length"):
            self.max_input_length = config.max_seq_length
        else:
            self.max_input_length = config.max_position_embeddings

        with safe_open(model_path / "model.safetensors", framework="pt") as f:
            model = FlashJinaBertModel(f, device, dtype, config)

        self.hidden_size = config.hidden_size
        self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
        self.device = device
        self.dtype = dtype
        self.hidden_size = config.hidden_size

        super(FlashJinaBert, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    def mean_pooling(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ):
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        kwargs = {"input_ids": batch.input_ids}
        kwargs["token_type_ids"] = batch.token_type_ids
        kwargs["position_ids"] = batch.position_ids
        input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
        max_input_lens = input_lens.max().item()
        kwargs["max_len"] = max_input_lens
        outputs = self.model.forward(**kwargs)
        embedding = self.mean_pooling(outputs, batch.attention_mask)
        cpu_results = embedding.view(-1).tolist()

        return [
            Embedding(
                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
            )
            for i in range(len(batch))
        ]

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Score]:
        pass
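

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the server): a standalone check of the
# ALiBi bias math implemented in JinaBertEncoder.rebuild_alibi_tensor above.
# The head count and sequence length below are arbitrary example values.
if __name__ == "__main__":
    n_heads, size = 8, 4

    # Geometric slopes for a power-of-two head count: 1/2, 1/4, ..., 1/256.
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    slopes = torch.tensor([start ** (i + 1) for i in range(n_heads)])

    # The bias is -slope * |i - j| per head, so attention decays with distance
    # instead of relying on absolute position embeddings.
    positions = torch.arange(size)
    distance = (positions[None, :] - positions[:, None]).abs()
    alibi = (-slopes[:, None, None] * distance.unsqueeze(0)).unsqueeze(0)
    print(alibi.shape)  # torch.Size([1, 8, 4, 4])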