def __init__()

in backends/python/server/text_embeddings_server/models/flash_bert.py


    def __init__(self, prefix, handle, device, dtype, config: BertConfig):
        # Self-attention sub-block for this encoder layer; its weights are
        # read from the weight handle under `{prefix}.attention`.
        self.attention = BertAttention(
            f"{prefix}.attention", handle, device, dtype, config
        )

        # Feed-forward up-projection (intermediate dense). The weight is
        # transposed once at load time so the forward pass can multiply with
        # it directly, then cast to the target dtype and moved to the device.
        self.intermediate_weight = (
            handle.get_tensor(f"{prefix}.intermediate.dense.weight")
            .T.to(dtype)
            .to(device)
        )
        self.intermediate_bias = (
            handle.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)
        )

        # Pick the intermediate activation. Gelu variants are routed through
        # torch.nn.functional.gelu so that "gelu_fast" / "gelu_pytorch_tanh"
        # checkpoints keep their tanh approximation; every other activation
        # name is resolved through the ACT2FN registry.
        act = config.hidden_act
        if "gelu" not in act:
            self.intermediate_act_fn = ACT2FN[act]
        else:
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.intermediate_act_fn = lambda x: torch.nn.functional.gelu(
                x, approximate=approximate
            )

        # Feed-forward down-projection (output dense), also pre-transposed,
        # followed by the layer norm applied after the FFN.
        self.output_weight = (
            handle.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
        )
        self.output_bias = (
            handle.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
        )
        self.layer_norm = FastLayerNorm(
            f"{prefix}.output.LayerNorm", handle, device, dtype, config
        )
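
The activation selection is the subtlest part of this constructor, so here is a minimal standalone sketch of the same branch. `resolve_act_fn` is a hypothetical helper written for illustration; the sketch assumes `ACT2FN` is the standard registry from `transformers.activations` and requires only `torch` and `transformers`:

import torch
from transformers.activations import ACT2FN


def resolve_act_fn(act: str):
    # Hypothetical helper mirroring the branch in __init__ above: gelu
    # variants go through torch.nn.functional.gelu so the tanh approximation
    # used by "gelu_fast"/"gelu_pytorch_tanh" checkpoints is preserved;
    # all other names are looked up in the ACT2FN registry.
    if "gelu" not in act:
        return ACT2FN[act]
    approximate = "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
    return lambda x: torch.nn.functional.gelu(x, approximate=approximate)


x = torch.randn(4, 8)
exact = resolve_act_fn("gelu")(x)        # erf-based gelu (approximate="none")
tanh = resolve_act_fn("gelu_fast")(x)    # tanh approximation
print((exact - tanh).abs().max())        # small but nonzero difference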
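
Similarly, the `.T` applied to the dense weights exists because checkpoints store linear weights as `(out_features, in_features)`, while a plain matmul wants the transpose. Pre-transposing once at load time lets the forward pass fuse bias and matmul in a single addmm-style call instead of transposing on every step. A small sketch of the equivalence, with illustrative names and shapes:

import torch

x = torch.randn(4, 8)                    # (tokens, hidden)
w = torch.randn(3, 8)                    # checkpoint layout: (out, in)
b = torch.randn(3)

w_t = w.T                                # what __init__ stores
y_addmm = torch.addmm(b, x, w_t)         # fused bias + matmul on the
                                         # pre-transposed weight
y_linear = torch.nn.functional.linear(x, w, b)        # same math, untransposed
print(torch.allclose(y_addmm, y_linear, atol=1e-6))   # True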