def get_multi_weights_col()

in backends/gaudi/server/text_generation_server/layers/fp8.py


    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: force to_device=False, as fp8 weights do not yet support torch.cat on device
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
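        # Record each shard's shape before fusing; the per-prefix scale loading below needs them.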
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            if self.weight_block_size is not None:
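                # Block-wise quantization: the checkpoint stores per-block inverse
                # scales, sharded and concatenated exactly like the weights.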
                scale = [
                    weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
                    for p in prefixes
                ]
                scale = torch.cat(scale, dim=dim)
                scale = scale.to(weights.device)
                return Fp8Weight(
                    weight=w,
                    weight_scale=scale,
                    activation_scale_ub=self.activation_scale_ub,
                    dtype=weights.dtype,
                    weight_block_size=self.weight_block_size,
                )

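            # Per-tensor (or per-row) scales: load one scale per prefix and flatten.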
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

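            # Requantize all shards to the single largest scale so the fused
            # tensor can be dequantized with one per-tensor scale.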
            logical_widths = [x[0] for x in shapes]
            w, scale = requantize_with_max_scale(
                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
            )

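            # Static activation scales are optional; the assert enforces that either
            # none or all prefixes provide one, and they collapse to their max.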
            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

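            # Pre-quantized per-tensor fp8 weight with a single fused scale.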
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            # Unquantized checkpoint but fp8 was requested: Fp8Weight carries no
            # precomputed scale, so quantization happens downstream.
            return Fp8Weight(weight=w, dtype=weights.dtype)

        # No fp8 involved: return the raw concatenated tensor.
        return UnquantizedWeight(w)
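
A minimal usage sketch for context. It assumes the upstream text-generation-inference API: that this method lives on HybridFP8UnquantLoader (the loader class defined in this file upstream), that its constructor takes activation_scale_ub and to_fp8, and that `weights` is a Weights instance wired to this loader; the layer prefixes are illustrative. Details may differ in the Gaudi backend.

    # Sketch, not the canonical call path: in TGI this is normally reached via
    # weights.get_multi_weights_col(...), which delegates to the weights_loader.
    loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=False)

    # Fuse the column-parallel q/k/v projections of the first decoder layer.
    # Returns Fp8Weight for pre-quantized fp8 checkpoints, Fp8Weight without a
    # scale when to_fp8=True, and UnquantizedWeight otherwise.
    w = loader.get_multi_weights_col(
        weights,  # text_generation_server.utils.weights.Weights instance (assumed)
        prefixes=[f"model.layers.0.self_attn.{p}_proj" for p in ("q", "k", "v")],
        dim=0,  # concatenate along the output (column) dimension
    )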