def validate_bettertransformer()

in optimum/bettertransformer/models/base.py


    def validate_bettertransformer(self):
        r"""
        A wrapper function to validate the `BetterTransformer` implementation. Implements the most relevant checks
        present in: https://github.com/pytorch/pytorch/blob/0fc7de398636f4b53e6c3fde38b4e48a5ff5b37d/torch/nn/modules/transformer.py#L457-L475
        """
        # Sanity checks
        if self.num_heads is None:
            raise ValueError("Number of heads not set for `BetterTransformer` integration.")

        if self.embed_dim is None:
            raise ValueError("Embedding dimension not set for `BetterTransformer` integration.")

        if self.norm2_eps is None or self.norm1_eps is None:
            raise ValueError("`norm2_eps` and `norm1_eps` not set for `BetterTransformer` integration.")

        # Check positional embedding
        if self.pos_emb_type is not None and self.pos_emb_type != "absolute":
            raise ValueError(
                f"Positional embedding type {self.pos_emb_type} not " "supported for `BetterTransformer` integration"
            )

        # Check norm1 epsilon and norm2 epsilon equality
        if self.norm1_eps != self.norm2_eps:
            raise ValueError("norm1_eps and norm2_eps must be equal for `BetterTransformer` integration.")

        # Check activation function
        if self.act_fn in USE_AT_OWN_RISK_ACTIVATION_FUNCTIONS:
            logger.warning(
                f"Overridding {self.act_fn} activation with gelu. Use the transformed model at your own risk, the output logits could be significantly different."
            )
            self.act_fn = "gelu"
        elif self.act_fn not in SUPPORTED_ACTIVATION_FUNCTIONS:
            raise ValueError(
                f"Activation function {self.act_fn} not supported" " for `BetterTransformer` integration."
            )
        self.use_gelu = (self.act_fn == "gelu") or (self.act_fn == "gelu_new")

        # Check that num_heads is even
        if self.num_heads % 2 == 1:
            raise ValueError(
                f"Number of heads {self.num_heads} is not supported for `BetterTransformer` integration."
                " Number of heads must be even."
            )