in backends/python/server/text_embeddings_server/models/flash_bert.py [0:0]
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
    """Load fused QKV and output-projection parameters for one BERT attention layer.

    Args:
        prefix: checkpoint key prefix for this attention layer
            (e.g. ``"encoder.layer.0.attention"``).
        handle: safetensors-style handle exposing ``get_tensor(name)``.
        device: target torch device for all parameters.
        dtype: target torch dtype for all parameters.
        config: BertConfig supplying ``hidden_size`` and ``num_attention_heads``.
    """
    query_weight = handle.get_tensor(f"{prefix}.self.query.weight")
    query_bias = handle.get_tensor(f"{prefix}.self.query.bias")
    key_weight = handle.get_tensor(f"{prefix}.self.key.weight")
    key_bias = handle.get_tensor(f"{prefix}.self.key.bias")
    value_weight = handle.get_tensor(f"{prefix}.self.value.weight")
    value_bias = handle.get_tensor(f"{prefix}.self.value.bias")

    # Fuse Q/K/V into one matrix so the projection is a single matmul;
    # transposed, presumably so the forward pass computes x @ qkv_weight
    # (verify against the caller). A single .to(dtype=..., device=...)
    # performs one conversion instead of the two intermediate copies
    # produced by chained .to(dtype).to(device).
    self.qkv_weight = (
        torch.cat([query_weight, key_weight, value_weight])
        .T.to(dtype=dtype, device=device)
    )
    self.qkv_bias = torch.cat([query_bias, key_bias, value_bias]).to(
        dtype=dtype, device=device
    )

    self.dense_weight = (
        handle.get_tensor(f"{prefix}.output.dense.weight")
        .T.to(dtype=dtype, device=device)
    )
    self.dense_bias = handle.get_tensor(f"{prefix}.output.dense.bias").to(
        dtype=dtype, device=device
    )

    # Post-attention LayerNorm loaded from the same checkpoint handle.
    self.layer_norm = FastLayerNorm(
        f"{prefix}.output.LayerNorm", handle, device, dtype, config
    )

    self.head_size = config.hidden_size // config.num_attention_heads
    # 1/sqrt(head_size): standard scaled dot-product attention scaling.
    self.softmax_scale = self.head_size**-0.5
    self.num_heads = config.num_attention_heads
    self.device = device