in backends/python/server/text_embeddings_server/models/jinaBert_model.py [0:0]
def __init__(self, prefix, handle, device, dtype, config: JinaBertConfig):
    """Assemble one layer from the tensors stored under ``prefix``.

    Args:
        prefix: Name prefix shared by this layer's tensors in ``handle``.
        handle: Weight source; only its ``get_tensor(name)`` method is
            used directly here (it is also forwarded to the sub-modules).
        device: Target handed to ``.to(device)`` for each loaded tensor.
        dtype: Target handed to ``.to(dtype)`` for each loaded tensor.
        config: Model configuration; ``feed_forward_type`` is read from it
            and the whole object is kept on ``self.config``.

    Raises:
        ValueError: If ``config.feed_forward_type`` does not end with
            ``"glu"``.
    """

    def load(name):
        # Same chain for every parameter: convert dtype, then move to device.
        return handle.get_tensor(name).to(dtype).to(device)

    self.attention = JinaBertAttention(
        f"{prefix}.attention", handle, device, dtype, config
    )
    self.config = config
    self.feed_forward_type = config.feed_forward_type

    # Scale and shift parameters for the layer's two layer norms.
    self.layer_norm_1_weight = load(f"{prefix}.layer_norm_1.weight")
    self.layer_norm_1_bias = load(f"{prefix}.layer_norm_1.bias")
    self.layer_norm_2_weight = load(f"{prefix}.layer_norm_2.weight")
    self.layer_norm_2_bias = load(f"{prefix}.layer_norm_2.bias")

    # Only feed-forward types whose name ends in "glu" have an MLP here.
    if not self.feed_forward_type.endswith("glu"):
        raise ValueError(
            f"feed_forward_type {self.feed_forward_type} not supported"
        )
    self.mlp = JinaBertGLUMLP(f"{prefix}.mlp", handle, device, dtype, config)