def get_multi_weights_col()

in server/text_generation_server/layers/compressed_tensors/w8an_fp.py [0:0]
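
Loads the column-sharded .weight tensor for each prefix, concatenates them along dim on the target device, gathers the matching weight scales and optional input scales, and returns the result as an Fp8Weight.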


    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device=False because fp8 weights do not yet support torch.cat on device
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

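        # Per-prefix weight scales (scalar or per-output-channel) are loaded and
        # concatenated so they line up with the fused weight.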
        weight_scale = None
        if self.load_weight_scale:
            weight_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)

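        # Input scales are optional; when every prefix provides one, they are
        # collapsed into a single shared maximum for the fused layer.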
        input_scale = None
        if self.load_input_scale:
            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

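        # On ROCm, normalize e4m3fn weights to the native fp8 format and, when
        # there is exactly one scale per prefix, requantize all shards to a
        # shared maximum scale.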
        if self.load_weight_scale and SYSTEM == "rocm":
            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
                w, weight_scale, input_scale
            )

            if weight_scale.numel() == len(prefixes):
                logical_widths = [x[0] for x in shapes]
                w, weight_scale = requantize_with_max_scale(
                    w, weight_scale.to(weights.device), logical_widths, weights.dtype
                )

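        # Package the fused weight and its scales as an Fp8Weight for the
        # downstream fp8 linear layer.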
        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )
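
For context, a minimal usage sketch. The prefix names and the loader/weights variables below are illustrative, not taken from this file; the method is called with the prefixes of the column-parallel shards that should be fused, such as the q/k/v projections of one attention layer.

# Illustrative only: fuse the q/k/v projections of the first attention layer.
# `loader` is assumed to be an instance of the loader class defined in this
# file, and `weights` the Weights object it operates on.
prefixes = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
fused = loader.get_multi_weights_col(weights, prefixes=prefixes, dim=0)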