backends/python/server/text_embeddings_server/models/pooling.py (30 lines of code) (raw):

from abc import ABC, abstractmethod

import torch
from opentelemetry import trace
from sentence_transformers.models import Pooling
from torch import Tensor

tracer = trace.get_tracer(__name__)


class _Pooling(ABC):
    @abstractmethod
    def forward(self, model_output, attention_mask) -> Tensor:
        pass


class DefaultPooling(_Pooling):
    # Wraps the sentence_transformers Pooling module for the standard
    # pooling modes (mean, cls, max, ...).
    def __init__(self, hidden_size, pooling_mode) -> None:
        assert (
            pooling_mode != "splade"
        ), "Splade pooling is not supported for DefaultPooling"
        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)

    @tracer.start_as_current_span("pooling")
    def forward(self, model_output, attention_mask) -> Tensor:
        pooling_features = {
            "token_embeddings": model_output[0],
            "attention_mask": attention_mask,
        }
        return self.pooling.forward(pooling_features)["sentence_embedding"]


class SpladePooling(_Pooling):
    @tracer.start_as_current_span("pooling")
    def forward(self, model_output, attention_mask) -> Tensor:
        # SPLADE pooling: log(1 + ReLU(logits)), zeroed on padding tokens,
        # then max-pooled over the sequence dimension.
        hidden_states = torch.relu(model_output[0])
        hidden_states = (1 + hidden_states).log()
        hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
        return hidden_states.max(dim=1).values
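
A minimal sketch of how the two pooling classes above could be exercised with dummy tensors. The shapes, random values, and the import path are assumptions for illustration only (the module path matches the header, but whether the package is importable depends on how the server is installed); `DefaultPooling` expects hidden states, while `SpladePooling` expects MLM logits.

import torch

# Assumed import path, mirroring backends/python/server/text_embeddings_server/models/pooling.py
from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling

batch, seq_len, hidden, vocab = 2, 4, 8, 16
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Dense pooling: model_output[0] plays the role of the last hidden state,
# shape (batch, seq_len, hidden).
dense_output = (torch.randn(batch, seq_len, hidden),)
mean_pooling = DefaultPooling(hidden_size=hidden, pooling_mode="mean")
sentence_embedding = mean_pooling.forward(dense_output, attention_mask)
print(sentence_embedding.shape)  # torch.Size([2, 8])

# SPLADE pooling: model_output[0] plays the role of MLM logits,
# shape (batch, seq_len, vocab); the result is one weight per vocabulary term.
mlm_output = (torch.randn(batch, seq_len, vocab),)
splade_pooling = SpladePooling()
term_weights = splade_pooling.forward(mlm_output, attention_mask)
print(term_weights.shape)  # torch.Size([2, 16])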