def forward()

in server/text_generation_server/layers/fp8.py


    def forward(self, input: torch.Tensor) -> torch.Tensor:
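        # Quantize the activations to FP8 and dispatch to the best available GEMM:
        # block-wise FP8 (DeepSeek-V3 style), CUTLASS with per-token scales, or
        # torch._scaled_mm with per-tensor scales as the fallback.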
        if self.weight_block_size is not None:
            # https://arxiv.org/pdf/2412.19437
            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
            # channels).
            qinput, scale = per_token_group_quant_fp8(
                input, self.weight_block_size[1]
            )
            output = w8a8_block_fp8_matmul(
                qinput,
                self.qweight,
                scale,
                self.scale,
                self.weight_block_size,
                output_dtype=input.dtype,
            )

            if self.bias is not None:
                output = output + self.bias
            return output.to(dtype=input.dtype)
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound, scalar=False
            )
            return quantization.cutlass_scaled_mm(
                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
            )

        # Fallback: quantize the activations with a single per-tensor scale and
        # run the GEMM through torch._scaled_mm.
        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1
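        # On ROCm, torch._scaled_mm is only called with the real scales when both
        # are scalar (per-tensor); otherwise the GEMM runs with identity scales and
        # the scales are applied manually in the branch below.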

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                out_dtype=self.dtype,
                scale_a=scale,
                scale_b=self.scale,
                bias=self.bias,
            )

            # Some torch versions return an (output, amax) tuple from _scaled_mm.
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]
        else:
            device_identity = None
            if SYSTEM == "rocm":
                device_identity = self.get_shared_device_identity(self.qweight.device)
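            # The identity scale leaves the FP8 GEMM result unscaled (in float32);
            # the real activation and weight scales are applied manually below.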

            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]

            # Rescale the float32 result with the activation scale and the
            # (possibly per-channel) weight scale.
            output = output * scale * self.scale.t()
            if self.bias is not None:
                output = output + self.bias

            output = output.to(dtype=self.dtype)

        return output
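
For reference, below is a minimal sketch of the 1x128 per-token-group activation quantization described in the comment above. It is not the per_token_group_quant_fp8 kernel used here; the function name, shapes, and clamping details are illustrative assumptions.

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def per_token_group_quant_sketch(x: torch.Tensor, group_size: int = 128):
    # Hypothetical stand-in for per_token_group_quant_fp8: one scale per
    # (token, 128-channel group), chosen so the group's absmax maps to FP8_MAX.
    num_tokens, hidden_size = x.shape
    groups = x.reshape(num_tokens, hidden_size // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = (amax / FP8_MAX).to(torch.float32)
    q = (groups / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    # Dequantization is q.to(x.dtype) * scale, applied group-wise.
    return q.reshape(num_tokens, hidden_size), scale.squeeze(-1)

In the block-wise path above, w8a8_block_fp8_matmul then combines these per-group activation scales with the per-block (128x128) weight scales during the matmul, which is why weight_block_size is passed alongside both scale tensors.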