in backends/python/server/text_embeddings_server/models/flash_bert.py
def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
    if isinstance(batch, PaddedBatch):
        # Padded path: derive per-sequence lengths from the attention mask and
        # build a dense additive attention mask for the model.
        input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
        max_input_lens = 0  # This value will not be used
        # Cumulative sequence lengths: [0, len_0, len_0 + len_1, ...]
        cu_seqlens = torch.cat(
            (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
        )
        mask = batch.attention_mask.bool()
        bsz, tgt_len = mask.size()
        min_val = torch.finfo(self.dtype).min
        # Additive mask: padding positions keep the most negative representable
        # value so they contribute nothing after softmax; real tokens get 0.0.
        attn_mask = torch.full(
            [bsz, 1, tgt_len, tgt_len],
            fill_value=min_val,
            device=self.device,
            dtype=self.dtype,
        )
        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
        attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
    elif isinstance(batch, FlashBatch):
        # Flash (variable-length) path: the batch already carries cu_seqlens and
        # the maximum sequence length, so no dense mask is needed.
        cu_seqlens = batch.cu_seqlens
        mask = None
        attn_mask = None
        max_input_lens = batch.max_s
    embedding = self.model.forward(
        input_ids=batch.input_ids,
        token_type_ids=batch.token_type_ids,
        position_ids=batch.position_ids,
        cu_seqlens=cu_seqlens,
        max_s=max_input_lens,
        mask=mask,
        attn_mask=attn_mask,
    )
    # Flatten to a Python list and slice one hidden_size-wide embedding per
    # sequence in the batch.
    cpu_results = embedding.view(-1).tolist()
    return [
        Embedding(
            values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
        )
        for i in range(len(batch))
    ]
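
For reference, here is a minimal standalone sketch of the preprocessing done in the PaddedBatch branch above. The helper name build_padded_inputs and the example mask are hypothetical, and the device argument is omitted; this only illustrates the tensor manipulation, it is not code from the repository.

import torch


def build_padded_inputs(attention_mask: torch.Tensor, dtype: torch.dtype):
    # Per-sequence token counts: last column of the running sum over the mask.
    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
    # Cumulative sequence lengths, e.g. lengths [2, 3] -> cu_seqlens [0, 2, 5].
    cu_seqlens = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

    mask = attention_mask.bool()
    bsz, tgt_len = mask.size()
    min_val = torch.finfo(dtype).min
    # Start every position at the most negative representable value...
    attn_mask = torch.full([bsz, 1, tgt_len, tgt_len], fill_value=min_val, dtype=dtype)
    # ...then zero out the columns corresponding to real (non-padding) tokens.
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
    return cu_seqlens, attn_mask.masked_fill(expanded_mask, 0.0)


# Two sequences of lengths 2 and 3, padded to 4 tokens.
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
cu_seqlens, attn_mask = build_padded_inputs(attention_mask, torch.float32)
print(cu_seqlens)       # tensor([0, 2, 5], dtype=torch.int32)
print(attn_mask.shape)  # torch.Size([2, 1, 4, 4])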