backends/python/server/text_embeddings_server/models/flash_bert.py
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
    # Load the raw embedding weight tensors from the weights handle and
    # move them to the target dtype and device.
    self.word_embeddings_weight = (
        handle.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
    )
    self.token_type_embeddings_weight = (
        handle.get_tensor(f"{prefix}.token_type_embeddings.weight")
        .to(dtype)
        .to(device)
    )

    # Only absolute position embeddings are supported by this flash attention path.
    if config.position_embedding_type == "absolute":
        self.position_embeddings_weight = (
            handle.get_tensor(f"{prefix}.position_embeddings.weight")
            .to(dtype)
            .to(device)
        )
    else:
        raise NotImplementedError(
            "FlashBert only supports absolute position embeddings"
        )

    # Fused layer norm applied to the summed embeddings.
    self.layer_norm = FastLayerNorm(
        f"{prefix}.LayerNorm", handle, device, dtype, config
    )