def forward()

in server/text_generation_server/models/custom_modeling/idefics_modeling.py


    def forward(self, hidden_states, residual=None):
        if SYSTEM == "ipex":
            # Lazy import: the Intel extension is only required on IPEX systems.
            import intel_extension_for_pytorch as ipex

            out = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                residual is not None,
            )
            return out
        elif hidden_states.shape[-1] > 8192:
            # The fused dropout_add_ln kernel only supports hidden sizes up to 8192,
            # so fall back to a plain PyTorch RMSNorm for larger hidden dimensions.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states
        elif SYSTEM == "cuda":
            # Faster post-attention RMSNorm via the fused dropout_add_ln CUDA kernel.
            # Flatten any leading dimensions to 2D (tokens, hidden) before the kernel call.
            unwrap = False
            if len(hidden_states.shape) > 2:
                unwrap = True
                shape = hidden_states.shape
                hidden_states = hidden_states.reshape(-1, shape[-1])

            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            # res is the residual stream returned by the kernel; it is not used further here.
            if res is None:
                res = hidden_states

            if unwrap:
                normed_hidden_states = normed_hidden_states.view(*shape)

            return normed_hidden_states
        elif SYSTEM == "rocm":
            # Use the vLLM RMSNorm kernel, which can be compiled for ROCm, instead of
            # the Flash Attention one, which cannot.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            unwrap = False
            if len(hidden_states.shape) > 2:
                unwrap = True
                shape = hidden_states.shape
                hidden_states = hidden_states.reshape(-1, shape[-1])

            out = torch.empty_like(hidden_states)
            ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )

            if unwrap:
                out = out.view(*shape)

            return out
        else:
            raise ValueError(
                "Your system does not seem to be supported. Please check your installation or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
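
All four backends above compute the same math: an optional residual add followed by RMSNorm with the layer's weight. For a quick sanity check outside the serving stack, here is a minimal pure-PyTorch sketch of that computation, mirroring the large-hidden-size fallback branch; the `reference_rms_norm` helper and the example shapes are illustrative only, not part of this file.

import torch


def reference_rms_norm(hidden_states, weight, eps, residual=None):
    # Optional residual add, matching the non-CUDA branches above.
    if residual is not None:
        hidden_states = hidden_states + residual
    # RMSNorm proper: normalize by the root-mean-square, accumulated in float32.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # Cast back to the weight dtype (fp16/bf16) before applying the learned scale.
    return weight * hidden_states.to(weight.dtype)


# Example: shapes and dtypes as typically seen post-attention.
weight = torch.ones(4096, dtype=torch.float16)
x = torch.randn(2, 8, 4096).to(torch.float16)
out = reference_rms_norm(x, weight, eps=1e-6)
print(out.shape)  # torch.Size([2, 8, 4096])

Accumulating the variance in float32 before casting back matches the fallback branch above and avoids overflow when squaring half-precision activations.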