def __init__()

in backends/python/server/text_embeddings_server/models/flash_bert.py


    def __init__(self, prefix, handle, device, dtype, config: BertConfig):
        # Build one attention block's parameters: load the Q/K/V projections
        # from the safetensors handle, fuse them into a single QKV matrix,
        # and precompute the softmax scaling factor.
        query_weight = handle.get_tensor(f"{prefix}.self.query.weight")
        query_bias = handle.get_tensor(f"{prefix}.self.query.bias")
        key_weight = handle.get_tensor(f"{prefix}.self.key.weight")
        key_bias = handle.get_tensor(f"{prefix}.self.key.bias")
        value_weight = handle.get_tensor(f"{prefix}.self.value.weight")
        value_bias = handle.get_tensor(f"{prefix}.self.value.bias")

        # Fuse query/key/value so all three projections run as one GEMM.
        # The concatenated weight is transposed up front so the forward pass
        # can use a plain hidden_states @ qkv_weight matmul.
        self.qkv_weight = (
            torch.cat([query_weight, key_weight, value_weight]).T.to(dtype).to(device)
        )
        self.qkv_bias = (
            torch.cat([query_bias, key_bias, value_bias]).to(dtype).to(device)
        )

        # Output projection applied after attention, also stored transposed.
        self.dense_weight = (
            handle.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
        )
        self.dense_bias = (
            handle.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
        )

        # Post-attention residual layer norm.
        self.layer_norm = FastLayerNorm(
            f"{prefix}.output.LayerNorm", handle, device, dtype, config
        )

        # Attention geometry: per-head dimension and the 1/sqrt(head_size)
        # scale applied to attention scores before the softmax.
        self.head_size = config.hidden_size // config.num_attention_heads
        self.softmax_scale = self.head_size**-0.5
        self.num_heads = config.num_attention_heads
        self.device = device
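
For context, here is a minimal sketch of how the fused tensors built above are typically consumed in a forward pass. The helper name project_qkv, the hidden_states shape, and the use of torch.addmm are assumptions for illustration, not taken from this file:

    import torch

    # Hypothetical helper (not part of flash_bert.py): applies the fused QKV
    # projection produced by the constructor above.
    def project_qkv(hidden_states, qkv_weight, qkv_bias, num_heads, head_size):
        # hidden_states: [total_tokens, hidden_size]. Because qkv_weight was
        # stored transposed, a single addmm computes all three projections at
        # once, yielding [total_tokens, 3 * hidden_size].
        qkv = torch.addmm(qkv_bias, hidden_states, qkv_weight)
        # Split back into query/key/value and reshape to the per-head layout
        # an attention kernel expects.
        q, k, v = qkv.split(num_heads * head_size, dim=1)
        q = q.view(-1, num_heads, head_size)
        k = k.view(-1, num_heads, head_size)
        v = v.view(-1, num_heads, head_size)
        return q, k, v

Fusing the three projections trades three small GEMMs for one larger one, which generally uses the GPU better. For a BERT-base config (hidden_size 768, 12 attention heads), head_size is 64 and softmax_scale is 64**-0.5 = 0.125.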