def get_multi_weights_col()

in backends/gaudi/server/text_generation_server/layers/gptq/__init__.py

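Loads the column-sharded quantized tensors (`qweight`, `qzeros`, `scales`) for a list of fused prefixes, concatenates them along the packed output dimension, and returns a `GPTQWeight`. AWQ checkpoints are repacked into the GPTQ/Exllama layout on the fly; layers excluded from quantization fall back to the default (unquantized) loader.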

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
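        # `dim` is only forwarded to the unquantized fallback; the packed GPTQ
        # tensors below are always concatenated along dim=1.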
        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
            # Skipped layers stay unquantized; the loader must be instantiated
            # (assuming the DefaultWeightsLoader(weight_class) constructor from
            # utils.weights), otherwise `weights` would bind to `self`.
            return DefaultWeightsLoader(UnquantizedWeight).get_multi_weights_col(
                weights, prefixes, dim
            )
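        # qweight is int32-packed along the input dimension; each prefix
        # contributes this rank's shard of output columns (dim=1).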
        try:
            qweight = torch.cat(
                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
            )
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
            )

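        # Per-group dequantization scales, sharded along the same output columns.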
        scales = torch.cat(
            [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
        )

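        # Refresh the GPTQ parameters (bits, groupsize, desc_act, quant_method)
        # from the checkpoint; the branches below depend on them.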
        self._get_gptq_params(weights)

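        # Packed zero points, sharded and fused the same way as qweight.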
        qzeros = torch.cat(
            [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
        )

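        # The exllama kernels support only 4-bit GPTQ without act-order (desc_act).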
        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act

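        # Native GPTQ checkpoints ship a g_idx tensor (input row -> quantization
        # group); the fused prefixes share the same input rows, so their g_idx
        # tensors must be identical.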
        if self.quantize == "gptq" and self.quant_method == "gptq":
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
        elif self.quantize == "gptq" and self.quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

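            # AWQ interleaves bits differently; repack qweight/qzeros into the
            # order expected by the GPTQ/exllama kernels.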
            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
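            # Exllama reconstructs group indices itself; the other kernels need
            # a synthetic monotone g_idx mapping each of the in_features rows
            # (qweight.shape[0] * 32 // bits) to its group.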
            if use_exllama:
                g_idx = None
            else:
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // self.bits),
                        device=qweight.device,
                    )
                    // self.groupsize
                ).to(dtype=torch.int32)
        else:
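            # Plain AWQ kernel path (quantize == "awq"): no group index is used.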
            g_idx = None

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_awq_kernel=self.quantize == "awq",
            use_exllama=use_exllama,
        )
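
Typical call site, as a minimal sketch: in this codebase a `Weights` instance is
constructed with a weights loader, and `Weights.get_multi_weights_col` delegates
to the method above. The layer prefixes here are hypothetical Llama-style names.

    # Hypothetical: fuse the q/k/v projections of one attention layer.
    # `weights` was built with this GPTQWeightsLoader as its weights_loader.
    qkv = weights.get_multi_weights_col(
        prefixes=[
            "model.layers.0.self_attn.q_proj",
            "model.layers.0.self_attn.k_proj",
            "model.layers.0.self_attn.v_proj",
        ],
        dim=0,
    )
    # qkv is a GPTQWeight holding the per-prefix qweight/qzeros/scales fused
    # along the packed output dimension.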