in backends/python/server/text_embeddings_server/models/flash_bert.py
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
    self.attention = BertAttention(
        f"{prefix}.attention", handle, device, dtype, config
    )

    # Feed-forward weights are transposed at load time (.T) so the
    # forward pass can compute bias + input @ weight with torch.addmm
    # and no per-call transpose; dtype and device are fixed here too.
    self.intermediate_weight = (
        handle.get_tensor(f"{prefix}.intermediate.dense.weight")
        .T.to(dtype)
        .to(device)
    )
    self.intermediate_bias = (
        handle.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)
    )

    # Resolve the activation: GELU variants go through
    # torch.nn.functional.gelu so that checkpoints configured with
    # "gelu_fast" / "gelu_pytorch_tanh" keep the tanh approximation;
    # every other activation is looked up in ACT2FN.
    act = config.hidden_act
    self.intermediate_act_fn = (
        ACT2FN[act]
        if "gelu" not in act
        else lambda x: torch.nn.functional.gelu(
            x,
            approximate="tanh"
            if act in ["gelu_fast", "gelu_pytorch_tanh"]
            else "none",
        )
    )

    self.output_weight = (
        handle.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
    )
    self.output_bias = (
        handle.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
    )
    self.layer_norm = FastLayerNorm(
        f"{prefix}.output.LayerNorm", handle, device, dtype, config
    )